""" Main training script """
2

3
import argparse
4
import glob
5
import os
6
import random
7
import sys
8
import time
9

10
import numpy as np
11
import torch
12
import torch.nn
13
from accelerate import Accelerator
14
from tqdm import tqdm
15
from transformers import (
16
    CLIPImageProcessor,
17
    get_constant_schedule_with_warmup,
18
    get_cosine_schedule_with_warmup,
19
    get_linear_schedule_with_warmup,
20
)
21

22
import wandb
23
from otter_ai import FlamingoForConditionalGeneration, OtterForConditionalGeneration
24

25
sys.path.append("../..")
26
from pipeline.mimicit_utils.data import get_data
27
from pipeline.train.distributed import world_info_from_env
28
from pipeline.train.train_utils import AverageMeter, get_checkpoint
29

30
os.environ["TOKENIZERS_PARALLELISM"] = "false"
31

32

33
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
34
# in PyTorch 1.12 and later.
35
torch.backends.cuda.matmul.allow_tf32 = True
36

37
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
38
torch.backends.cudnn.allow_tf32 = True
39

40

41
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--external_save_dir",
        type=str,
        default=None,
        help="set to save model to external path",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        action="store_true",
        help="whether to resume from a checkpoint; if set, the latest checkpoint in --external_save_dir is loaded",
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete the previous checkpoint when saving a new one",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="otter_9b",
        help="used to name the saving directory and wandb run",
    )
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to mmc4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--train_num_samples_mmc4", type=int, default=100)
    parser.add_argument("--train_num_samples_laion", type=int, default=100)
    parser.add_argument("--batch_size_mmc4", type=int, default=8)
    parser.add_argument("--batch_size_laion", type=int, default=8)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--dataset_resampled", action="store_true")
    parser.add_argument(
        "--mmc4_textsim_threshold",
        default=0.32,
        type=float,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )

    # parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=10000,
        help="save a checkpoint every n steps",
    )
    # Gradient accumulation: the effective optimization batch size is batch_size * gradient_accumulation_steps
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        help="local path to a pretrained model or a model identifier from huggingface.co",
        default=None,
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_laion", type=float, default=0.2)
    parser.add_argument("--warmup_steps", default=1000, type=int)
    parser.add_argument("--warmup_steps_ratio", default=None, type=float)
    parser.add_argument("--weight_decay", default=0.1, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp_bf16", "amp_bfloat16", "bf16", "amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # YH: Training detail
    parser.add_argument("--mask_lm_head", action="store_true")
    parser.add_argument(
        "--max-src-length",
        type=int,
        default=1024,
        help="the maximum src sequence length",
    )
    parser.add_argument(
        "--max-tgt-length",
        type=int,
        default=1024,
        help="the maximum target sequence length",
    )
    parser.add_argument("--patch-image-size", type=int, default=224)
    # saving the full HF model writes roughly 33GB of parameters for otter-9b, including the language and vision models
    parser.add_argument("--save_hf_model", default=False, action="store_true")
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument(
        "--wandb_project",
        type=str,
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
    )
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    return parser
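
# Example invocation (illustrative only; shard paths, checkpoint path, and run name are placeholders,
# and the model class -- Otter vs. Flamingo -- is selected from the run name in main()):
#   accelerate launch pretraining.py \
#       --pretrained_model_name_or_path /path/to/otter-or-flamingo-checkpoint \
#       --mmc4_shards "/path/to/mmc4/shard-{0000..0999}.tar" \
#       --laion_shards "/path/to/laion/shard-{0000..0999}.tar" \
#       --batch_size_mmc4 8 --batch_size_laion 8 \
#       --num_epochs 1 --run_name otter_9b --report_to_wandb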


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)
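
# Note: main() calls random_seed(args.seed) once before loading the model and again as
# random_seed(args.seed, args.rank) before building the dataloaders, so all ranks share
# one seed up front while each rank gets a rank-offset seed for data shuffling.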


def train_one_epoch(
    args,
    model,
    epoch,
    mmc4_loader,
    laion_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    accelerator,
    wandb,
):
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches

    assert num_batches_per_epoch_laion == num_batches_per_epoch_mmc4, "Number of batches in laion and mmc4 datasets must be the same"

    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
    answer_token_id = tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch
        total_losses = []

        #### LAION FORWARD PASS ####
        # (B, C, H, W) -> (B, 1, 1, C, H, W): Flamingo-style models expect vision_x shaped
        # (B, T_img, F, C, H, W); each LAION sample carries a single image with a single frame.
        images = batch_laion[0].to(device_id, non_blocking=True).unsqueeze(1).unsqueeze(1)

        input_ids = batch_laion[1][0].to(device_id, non_blocking=True)
        attention_mask = batch_laion[1][1].to(device_id, non_blocking=True)

        # mask out padding, the first token, and <image> tokens from the loss
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)  # Tensor.to is not in-place; keep the returned tensor

        with accelerator.autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]

        # model.eval()
        # model.text_tokenizer.padding_side = "left"
        # text_prompt_lang_x = model.text_tokenizer(
        #     [
        #         "<image>",
        #     ],
        #     return_tensors="pt",
        # )['input_ids']
        # outputs_debug = model.generate(
        #     vision_x=images.to(device_id),
        #     lang_x=text_prompt_lang_x.to(device_id),
        #     attention_mask=attention_mask.to(device_id),
        #     max_length=256,
        # )

        # print(model.text_tokenizer.batch_decode(outputs_debug))
        # print(model.text_tokenizer.batch_decode(input_ids))
        # model.train()

        #### LAION BACKWARD ####
        accelerator.backward(args.loss_multiplier_laion * loss_laion)
        total_losses.append(args.loss_multiplier_laion * loss_laion)

        #### MMC4 FORWARD PASS ####
        # (B, T_img, C, H, W) -> (B, T_img, 1, C, H, W): one frame per interleaved image
        images = batch_mmc4[0].to(device_id, non_blocking=True).unsqueeze(2)
        # move the text tensors to the training device alongside the images
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1).to(device_id, non_blocking=True)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1).to(device_id, non_blocking=True)

        # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100

        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
                labels[i][label_idx] = -100
                label_idx += 1

            # get index of all endofchunk tokens in the sequence
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                token_idx = endofchunk_idx + 1
                while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
                    labels[i][token_idx] = -100
                    token_idx += 1

        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)  # Tensor.to is not in-place; keep the returned tensor
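
        # Illustrative example of the masking above (token values are schematic, not real IDs):
        #   sequence: [BOS] intro  text  <image>  caption  <|endofchunk|>  more  <image>  caption2  [PAD]
        #   labels:   -100  -100   -100  -100     caption  <|endofchunk|>  -100  -100     caption2  -100
        # i.e. loss is only computed on text that follows an <image> token within its chunk;
        # everything before the first image, everything between a chunk's end and the next image,
        # padding, and the <image> tokens themselves are excluded.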

        # with accelerator.accumulate(model):
        with accelerator.autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]

        # model.text_tokenizer.padding_side = "left"
        # outputs_debug = model.generate(
        #     vision_x=images.to(device_id),
        #     lang_x=input_ids.to(device_id),
        #     attention_mask=attention_mask.to(device_id),
        #     max_length=256,
        # )

        # print(model.text_tokenizer.batch_decode(outputs_debug))
        # print(model.text_tokenizer.batch_decode(input_ids))

        #### MMC4 BACKWARD ####
        accelerator.backward(args.loss_multiplier_mmc4 * loss_mmc4)
        total_losses.append(args.loss_multiplier_mmc4 * loss_mmc4)
        #### Collect MMC4/LAION Loss Info ####
        total_loss_sum = sum(total_losses)
        mean_loss = total_loss_sum / len(total_losses)
        # accelerator.backward(total_loss_sum.to(device_id))

        def mask_embedding(m):
            if m.weight.requires_grad:
                zero_mask = torch.zeros_like(m.weight.grad)
                # zero_mask[answer_token_id] = torch.ones_like(zero_mask[answer_token_id])
                zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
                zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
                m.weight.grad = m.weight.grad * zero_mask

        if args.mask_lm_head:
            unwrapped_model = accelerator.unwrap_model(model)
            if unwrapped_model.lang_encoder.__class__.__name__ == "MPTForCausalLM":
                unwrapped_model.lang_encoder.transformer.wte.apply(mask_embedding)
            elif unwrapped_model.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
                unwrapped_model.lang_encoder.model.embed_tokens.apply(mask_embedding)
                unwrapped_model.lang_encoder.lm_head.apply(mask_embedding)
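
        # With --mask_lm_head, the mask_embedding hook above zeroes every gradient row of the
        # token embeddings (and, for Llama, the LM head) except the rows for <image> and
        # <|endofchunk|>, so only the newly added special-token rows receive updates this step
        # while the rest of the vocabulary is left untouched.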

        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # update step time and reset the timer on every rank, not just rank 0
        step_time_m.update(time.time() - end)
        end = time.time()

        if accelerator.sync_gradients:
            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                mmc4_samples_per_second = args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
                mmc4_samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
                laion_samples_per_second = args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
                laion_samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "mmc4_samples_per_second": mmc4_samples_per_second,
                        "mmc4_samples_per_second_per_gpu": mmc4_samples_per_second_per_gpu,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "mmc4_loss": loss_mmc4.item(),
                        "laion_loss": loss_laion.item(),
                        "mean_loss": mean_loss.item(),
                        "global_step": global_step // args.gradient_accumulation_steps,
                    },
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Mean Loss: {mean_loss.item():.3f}")
        # Periodically save a checkpoint during pretraining
        if ((num_steps + 1) % args.checkpointing_steps == 0) and args.rank == 0:
            if not os.path.exists(args.external_save_dir):
                os.makedirs(args.external_save_dir)

            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(unwrapped_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
            }
            print(f"Saving checkpoint to {args.external_save_dir}/checkpoint_steps{num_steps + 1}.pt")
            accelerator.save(
                checkpoint_dict,
                f"{args.external_save_dir}/checkpoint_steps{num_steps + 1}.pt",
            )
            # save the config
            print(f"Saving config to {args.external_save_dir}/config.json")
            unwrapped_model.config.save_pretrained(args.external_save_dir)
            if args.delete_previous_checkpoint:
                if (num_steps + 1) // args.checkpointing_steps >= 2:
                    previous_checkpoint_path = f"{args.external_save_dir}/checkpoint_steps{num_steps + 1 - args.checkpointing_steps}.pt"
                    if os.path.exists(previous_checkpoint_path):
                        os.remove(previous_checkpoint_path)


def main():
    parser = parse_args()
    # TODO: remove additional data args; all args should be handled by the parser above
    # parser = add_data_args(parser)
    args = parser.parse_args()

    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    args.local_rank, args.rank, args.world_size = world_info_from_env()
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

    device_id = accelerator.device

    random_seed(args.seed)
    if args.pretrained_model_name_or_path is not None:
        accelerator.print(f"Loading pretrained model from {args.pretrained_model_name_or_path}")
        if "otter" in args.run_name.lower():
            model = OtterForConditionalGeneration.from_pretrained(
                args.pretrained_model_name_or_path,
                device_map="auto",
                local_files_only=args.offline,
            )
        elif "flamingo" in args.run_name.lower():
            if accelerator.num_processes > 1:
                model = FlamingoForConditionalGeneration.from_pretrained(
                    args.pretrained_model_name_or_path,
                    device_map={"": device_id},
                    local_files_only=args.offline,
                )
            else:
                model = FlamingoForConditionalGeneration.from_pretrained(
                    args.pretrained_model_name_or_path,
                    device_map="auto",
                    local_files_only=args.offline,
                )
            model.text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
    else:
        # a pretrained checkpoint is required: the code below assumes `model` has been loaded
        model = None

    accelerator.wait_for_everyone()

    if model.lang_encoder.__class__.__name__ != "MPTForCausalLM":
        model.lang_encoder.resize_token_embeddings(len(model.text_tokenizer))

    args.tokenizer = model.text_tokenizer
    tokenizer = model.text_tokenizer
    random_seed(args.seed, args.rank)

    image_processor = CLIPImageProcessor()

    mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")
    laion_dataset = get_data(args, image_processor, tokenizer, "laion")

    def get_grouped_params(model):
        params_with_wd, params_without_wd = [], []

        def apply_decay(x):
            return "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x

        for n, p in model.named_parameters():
            # if p.requires_grad:
            if apply_decay(n):
                params_with_wd.append(p)
            else:
                params_without_wd.append(p)

        return [
            {"params": params_with_wd, "weight_decay": args.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]
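
    # Weight decay is applied only to the gated cross-attention weights, and even there the
    # gate parameters, norm layers, and biases are excluded; every other parameter falls into
    # the weight_decay=0.0 group.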

    # total_training_steps = ((args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size)) * args.num_epochs
    total_training_steps = mmc4_dataset.dataloader.num_batches * args.num_epochs

    args.external_save_dir = os.path.join(args.external_save_dir, args.run_name) if args.external_save_dir else args.run_name

    optimizer = torch.optim.AdamW(get_grouped_params(model), lr=args.learning_rate)

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    args.warmup_steps = total_training_steps * args.warmup_steps_ratio if args.warmup_steps_ratio is not None else args.warmup_steps

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)

    # check if a checkpoint exists for this run and resume from it; the optimizer and
    # lr_scheduler must already be constructed so their states can be restored as well
    resume_from_epoch = 0
    if os.path.exists(f"{args.external_save_dir}") and args.resume_from_checkpoint is True:
        checkpoint_list = glob.glob(f"{args.external_save_dir}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.external_save_dir}.")
        else:
            # checkpoints are named checkpoint_steps{N}.pt or checkpoint_epoch{E}.pt; sort by the trailing number
            resume_from_checkpoint_path = sorted(checkpoint_list, key=lambda x: int("".join(filter(str.isdigit, os.path.basename(x)))))[-1]
            print(f"Found checkpoint {resume_from_checkpoint_path} for run {args.external_save_dir}.")

            if args.rank == 0:
                print(f"Loading checkpoint from {resume_from_checkpoint_path}")
            checkpoint = torch.load(resume_from_checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"], False)
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
            resume_from_epoch = checkpoint["epoch"] + 1

    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

    # YH: hard-coded for DDP; related to "split_batches" in accelerate. This only works around the issue for now and needs further digging.
    if accelerator.num_processes > 1:
        lr_scheduler.split_batches = True

    model.train()

    for epoch in range(resume_from_epoch, args.num_epochs):
        laion_dataset.set_epoch(epoch)
        laion_loader = laion_dataset.dataloader

        mmc4_dataset.set_epoch(epoch)
        mmc4_loader = mmc4_dataset.dataloader

        train_one_epoch(
            args=args,
            model=model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            mmc4_loader=mmc4_loader,
            laion_loader=laion_loader,
            accelerator=accelerator,
            device_id=device_id,
            wandb=wandb,
        )
        if args.rank == 0:
            if not os.path.exists(args.external_save_dir):
                os.makedirs(args.external_save_dir)

            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(unwrapped_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
            }
            print(f"Saving checkpoint to {args.external_save_dir}/checkpoint_epoch{epoch}.pt")
            accelerator.save(checkpoint_dict, f"{args.external_save_dir}/checkpoint_epoch{epoch}.pt")
            # save the config
            unwrapped_model.config.save_pretrained(args.external_save_dir)
            if args.delete_previous_checkpoint:
                if epoch > 0:
                    os.remove(f"{args.external_save_dir}/checkpoint_epoch{epoch-1}.pt")

        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    if args.rank == 0:
        if not os.path.exists(args.external_save_dir):
            os.makedirs(args.external_save_dir)

        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(
            get_checkpoint(model=unwrapped_model),
            f"{args.external_save_dir}/final_weights.pt",
        )
        # save the config
        unwrapped_model.config.save_pretrained(args.external_save_dir)

        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.external_save_dir}/final_weights.pt")
        if args.save_hf_model:
            unwrapped_model.save_pretrained(f"{args.external_save_dir}")


if __name__ == "__main__":
    main()