#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import warnings
from pathlib import Path

import datasets
import numpy as np
import torch
from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
from tqdm.auto import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForMaskedImageModeling,
    SchedulerType,
    get_scheduler,
)
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


""" Pre-training a 🤗 Transformers model for simple masked image modeling (SimMIM)
without using HuggingFace Trainer.
Any model supported by the AutoModelForMaskedImageModeling API can be used.
"""

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.39.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Finetune a transformers model on a simple Masked Image Modeling task"
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="cifar10",
        help="Name of a dataset from the datasets package",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--image_column_name",
        type=str,
        default=None,
        help="The column name of the images in the files. If not set, will try to use 'image' or 'img'.",
    )
    parser.add_argument(
        "--train_dir",
        type=str,
        default=None,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--validation_dir",
        type=str,
        default=None,
        help="A folder containing the validation data.",
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Percent to split off of train for validation.",
    )
    parser.add_argument(
        "--mask_patch_size",
        type=int,
        default=32,
        help="The size of the square patches to use for masking.",
    )
    parser.add_argument(
        "--mask_ratio",
        type=float,
        default=0.6,
        help="Percentage of patches to mask.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help=(
            "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
            "checkpoint identifier on the hub. "
            "Don't set if you want to train a model from scratch."
        ),
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES),
    )
    parser.add_argument(
        "--config_name_or_path",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--config_overrides",
        type=str,
        default=None,
        help=(
            "Override some existing default config settings when a model is trained from scratch. Example: "
            "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        ),
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where do you want to store (cache) the pretrained models/datasets downloaded from the hub",
    )
    parser.add_argument(
        "--model_revision",
        type=str,
        default="main",
        help="The specific model version to use (can be a branch name, tag name or commit id).",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--image_processor_name",
        type=str,
        default=None,
        help="Name or path of preprocessor config.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help=(
            "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
            "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
        ),
    )
    parser.add_argument(
        "--use_auth_token",
        type=bool,
        default=None,
        help="The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
    )
    parser.add_argument(
        "--trust_remote_code",
        type=bool,
        default=False,
        help=(
            "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=None,
        help="The size (resolution) of each image. If not specified, will use `image_size` of the configuration.",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=None,
        help="The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration.",
    )
    parser.add_argument(
        "--encoder_stride",
        type=int,
        default=None,
        help="Stride to use for the encoder.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    # These two are referenced by `main()` when `--push_to_hub` is passed, so they must exist.
    parser.add_argument(
        "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="The initial learning rate for the [`AdamW`] optimizer.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=float,
        default=3.0,
        help="Total number of training epochs to perform (if not an integer, the fractional part of the final epoch is performed before stopping).",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Save the training states every n steps (pass an integer), or at the end of each epoch (pass 'epoch').",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If set, training resumes from the given checkpoint folder.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Where to store the final model.",
    )
    args = parser.parse_args()

    # Sanity checks
    data_files = {}
    if args.train_dir is not None:
        data_files["train"] = args.train_dir
    if args.validation_dir is not None:
        data_files["val"] = args.validation_dir
    args.data_files = data_files if data_files else None

    if args.push_to_hub:
        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."

    return args


class MaskGenerator:
    """
    A class to generate boolean masks for the pretraining task.

    A mask is a 1D tensor of shape ((input_size // model_patch_size) ** 2,) where the value
    is either 0 or 1, and where 1 indicates "masked".
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        if self.input_size % self.mask_patch_size != 0:
            raise ValueError("Input size must be divisible by mask patch size")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError("Mask patch size must be divisible by model patch size")

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return torch.tensor(mask.flatten())
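
# Sanity check of the mask shape (illustrative, not executed by the script): with the defaults
# (input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6), rand_size is
# 192 // 32 = 6 and scale is 32 // 4 = 8, so MaskGenerator()() has (6 * 8) ** 2 = 2304 entries,
# of which ceil(36 * 0.6) = 22 coarse patches (22 * 8 * 8 = 1408 entries) are set to 1.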


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    mask = torch.stack([example["mask"] for example in examples])
    return {"pixel_values": pixel_values, "bool_masked_pos": mask}


def main():
    args = parse_args()

    if args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        args.token = args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_mim_no_trainer", args)

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_log_kwargs = {}

    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            # Retrieve or infer repo_name
            repo_name = args.hub_model_id
            if repo_name is None:
                repo_name = Path(args.output_dir).absolute().name
            # Create repo and retrieve repo_id
            repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
            # Clone repo locally
            repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Initialize our dataset.
    ds = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        data_files=args.data_files,
        cache_dir=args.cache_dir,
        token=args.token,
    )

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in ds.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = ds["train"].train_test_split(args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Create config
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": args.cache_dir,
        "revision": args.model_revision,
        "token": args.token,
        "trust_remote_code": args.trust_remote_code,
    }
    if args.config_name_or_path:
        config = AutoConfig.from_pretrained(args.config_name_or_path, **config_kwargs)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if args.config_overrides is not None:
            logger.info(f"Overriding config: {args.config_overrides}")
            config.update_from_string(args.config_overrides)
            logger.info(f"New config: {config}")

    # make sure the decoder_type is "simmim" (only relevant for BEiT)
    if hasattr(config, "decoder_type"):
        config.decoder_type = "simmim"

    # adapt config
    args.image_size = args.image_size if args.image_size is not None else config.image_size
    args.patch_size = args.patch_size if args.patch_size is not None else config.patch_size
    args.encoder_stride = args.encoder_stride if args.encoder_stride is not None else config.encoder_stride

    config.update(
        {
            "image_size": args.image_size,
            "patch_size": args.patch_size,
            "encoder_stride": args.encoder_stride,
        }
    )
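
    # For example (illustrative): a fresh `--model_type swin` config has patch_size=4 and
    # encoder_stride=32 by default, so passing `--image_size 192` reproduces the SimMIM
    # defaults assumed by `MaskGenerator` above.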

    # create image processor
    if args.image_processor_name:
        image_processor = AutoImageProcessor.from_pretrained(args.image_processor_name, **config_kwargs)
    elif args.model_name_or_path:
        image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path, **config_kwargs)
    else:
        IMAGE_PROCESSOR_TYPES = {
            conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
        }
        image_processor = IMAGE_PROCESSOR_TYPES[args.model_type]()

    # create model
    if args.model_name_or_path:
        model = AutoModelForMaskedImageModeling.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir,
            revision=args.model_revision,
            token=args.token,
            trust_remote_code=args.trust_remote_code,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedImageModeling.from_config(
            config,
            token=args.token,
            trust_remote_code=args.trust_remote_code,
        )

    column_names = ds["train"].column_names

    if args.image_column_name is not None:
        image_column_name = args.image_column_name
    elif "image" in column_names:
        image_column_name = "image"
    elif "img" in column_names:
        image_column_name = "img"
    else:
        image_column_name = column_names[0]

    # transformations as done in original SimMIM paper
    # source: https://github.com/microsoft/SimMIM/blob/main/data/data_simmim.py
    transforms = Compose(
        [
            Lambda(lambda img: img.convert("RGB")),
            RandomResizedCrop(args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        ]
    )

    # create mask generator
    mask_generator = MaskGenerator(
        input_size=args.image_size,
        mask_patch_size=args.mask_patch_size,
        model_patch_size=args.patch_size,
        mask_ratio=args.mask_ratio,
    )

    def preprocess_images(examples):
        """Preprocess a batch of images by applying transforms + creating a corresponding mask, indicating
        which patches to mask."""

        examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]]
        examples["mask"] = [mask_generator() for _ in range(len(examples[image_column_name]))]

        return examples

    if args.max_train_samples is not None:
        ds["train"] = ds["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    # Set the training transforms
    ds["train"].set_transform(preprocess_images)

    if args.max_eval_samples is not None:
        ds["validation"] = ds["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
    # Set the validation transforms
    ds["validation"].set_transform(preprocess_images)

    # DataLoaders creation:
    train_dataloader = DataLoader(
        ds["train"],
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        ds["validation"],
        collate_fn=collate_fn,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will
    # be shorter in a multiprocess setup)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps
        if overrode_max_train_steps
        else args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    )

    # On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("mim_no_trainer", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(ds['train'])}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(int(args.max_train_steps)), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
            checkpoint_path = args.resume_from_checkpoint
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
            checkpoint_path = path
            path = os.path.basename(checkpoint_path)

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
            starting_epoch = resume_step // len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
            resume_step -= starting_epoch * len(train_dataloader)

    # Update the progress bar if we resumed from a checkpoint
    progress_bar.update(completed_steps)
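
    # Worked example of the resume arithmetic above (illustrative): resuming from "step_500" with
    # gradient_accumulation_steps=2 and a 300-batch dataloader gives resume_step = 500 * 2 = 1000
    # raw batches, so starting_epoch = 1000 // 300 = 3, completed_steps = 1000 // 2 = 500, and
    # resume_step = 1000 - 3 * 300 = 100 batches to skip within epoch 3.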

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader
        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        eval_loss = torch.mean(losses)

        logger.info(f"epoch {epoch}: eval_loss: {eval_loss}")

        if args.with_tracking:
            accelerator.log(
                {
                    "eval_loss": eval_loss,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                image_processor.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    if args.with_tracking:
        accelerator.end_training()

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            image_processor.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)


if __name__ == "__main__":
    main()