transformers
810 строк · 29.9 Кб
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16import argparse
17import logging
18import math
19import os
20import warnings
21from pathlib import Path
22
23import datasets
24import numpy as np
25import torch
26from accelerate import Accelerator, DistributedType
27from accelerate.utils import set_seed
28from datasets import load_dataset
29from huggingface_hub import Repository, create_repo
30from torch.utils.data import DataLoader
31from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
32from tqdm.auto import tqdm
33
34import transformers
35from transformers import (
36CONFIG_MAPPING,
37IMAGE_PROCESSOR_MAPPING,
38MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
39AutoConfig,
40AutoImageProcessor,
41AutoModelForMaskedImageModeling,
42SchedulerType,
43get_scheduler,
44)
45from transformers.utils import check_min_version, send_example_telemetry
46from transformers.utils.versions import require_version
47
48
""" Pre-training a 🤗 Transformers model for simple masked image modeling (SimMIM)
without using HuggingFace Trainer.
Any model supported by the AutoModelForMaskedImageModeling API can be used.
"""

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

# Fail fast when the installed Transformers release is older than the one this example targets.
# Remove at your own risks.
check_min_version("4.39.0.dev0")

# The example also needs a recent enough `datasets` release.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

# Config classes that have a masked-image-modeling head, and their `model_type` strings
# (the latter are listed in the `--model_type` help text).
MODEL_CONFIG_CLASSES = [*MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.keys()]
MODEL_TYPES = tuple([config_class.model_type for config_class in MODEL_CONFIG_CLASSES])
63
64
def _str_to_bool(value):
    """Convert a command-line string such as "true" or "0" into a real boolean.

    `type=bool` must not be used with argparse: `bool("False")` is `True` because any non-empty
    string is truthy.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args():
    """Parse and validate the command-line arguments of the script.

    Returns:
        The `argparse.Namespace` produced by the parser, with one extra attribute `data_files`:
        a dict that may contain the keys "train" and "validation" (taken from `--train_dir` and
        `--validation_dir`), or `None` when neither folder was given.

    Raises:
        ValueError: if `--push_to_hub` is passed without `--output_dir`.
    """
    parser = argparse.ArgumentParser(
        description="Finetune a transformers model on a simple Masked Image Modeling task"
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="cifar10",
        help="Name of a dataset from the datasets package",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--image_column_name",
        type=str,
        default=None,
        help="The column name of the images in the files. If not set, will try to use 'image' or 'img'.",
    )
    parser.add_argument(
        "--train_dir",
        type=str,
        default=None,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--validation_dir",
        type=str,
        default=None,
        help="A folder containing the validation data.",
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Percent to split off of train for validation.",
    )
    parser.add_argument(
        "--mask_patch_size",
        type=int,
        default=32,
        help="The size of the square patches to use for masking.",
    )
    parser.add_argument(
        "--mask_ratio",
        type=float,
        default=0.6,
        help="Percentage of patches to mask.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help=(
            "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
            "checkpoint identifier on the hub. "
            "Don't set if you want to train a model from scratch."
        ),
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES),
    )
    parser.add_argument(
        "--config_name_or_path",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--config_overrides",
        type=str,
        default=None,
        help=(
            "Override some existing default config settings when a model is trained from scratch. Example: "
            "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        ),
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where do you want to store (cache) the pretrained models/datasets downloaded from the hub",
    )
    parser.add_argument(
        "--model_revision",
        type=str,
        default="main",
        help="The specific model version to use (can be a branch name, tag name or commit id).",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--image_processor_name",
        type=str,
        default=None,
        help="Name or path of preprocessor config.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help=(
            "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
            "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
        ),
    )
    # Boolean options accept an explicit value ("--flag true" / "--flag false"); passing the bare
    # flag is treated as True.
    parser.add_argument(
        "--use_auth_token",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=None,
        help="The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
    )
    parser.add_argument(
        "--trust_remote_code",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help=(
            "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=None,
        help="The size (resolution) of each image. If not specified, will use `image_size` of the configuration.",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=None,
        help="The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration.",
    )
    parser.add_argument(
        "--encoder_stride",
        type=int,
        default=None,
        # `help` must be a plain string; a dict here makes `--help` crash while formatting.
        help="Stride to use for the encoder.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    # `main` reads both of the following attributes when `--push_to_hub` is set.
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="The initial learning rate for [`AdamW`] optimizer.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=float,
        default=3.0,
        help="Total number of training epochs to perform (if not an integer, will perform the decimal part percents of the last epoch before stopping training).",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Where to store the final model.",
    )
    args = parser.parse_args()

    # Sanity checks
    data_files = {}
    if args.train_dir is not None:
        data_files["train"] = args.train_dir
    if args.validation_dir is not None:
        # The split must be called "validation": that is the key `main` looks up to decide
        # whether a validation split still has to be carved out of the training data.
        data_files["validation"] = args.validation_dir
    args.data_files = data_files if data_files else None

    # Raise instead of `assert` so the check survives `python -O`.
    if args.push_to_hub and args.output_dir is None:
        raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")

    return args
340
341
class MaskGenerator:
    """
    Draws random patch masks for the masked image modeling pretraining task.

    Masking is decided on a coarse grid of `mask_patch_size` squares and then upsampled to the
    resolution of the model patches. Every call returns a flat 1D tensor holding
    (input_size // model_patch_size) ** 2 values, each 0 or 1, where 1 means "masked".
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        # Side length of the coarse masking grid, in mask patches.
        grid_side, leftover = divmod(input_size, mask_patch_size)
        if leftover != 0:
            raise ValueError("Input size must be divisible by mask patch size")

        # Number of model patches covered by one mask patch along each axis.
        upscale, leftover = divmod(mask_patch_size, model_patch_size)
        if leftover != 0:
            raise ValueError("Mask patch size must be divisible by model patch size")

        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        self.rand_size = grid_side
        self.scale = upscale

        self.token_count = grid_side**2
        # Round up so that at least `mask_ratio` of the coarse cells end up masked.
        self.mask_count = int(np.ceil(self.token_count * mask_ratio))

    def __call__(self):
        # Pick `mask_count` distinct coarse cells uniformly at random.
        chosen = np.random.permutation(self.token_count)[: self.mask_count]
        coarse = np.zeros(self.token_count, dtype=int)
        coarse[chosen] = 1

        # Blow each coarse cell up into a `scale` x `scale` square of model patches.
        coarse = coarse.reshape((self.rand_size, self.rand_size))
        fine = np.repeat(np.repeat(coarse, self.scale, axis=0), self.scale, axis=1)

        return torch.tensor(fine.flatten())
376
377
def collate_fn(examples):
    """Stack per-example tensors into one batch.

    Each example is a mapping with a "pixel_values" tensor and a "mask" tensor. The result holds
    the stacked images under "pixel_values" and the stacked masks under "bool_masked_pos".
    """
    images = []
    masks = []
    for item in examples:
        images.append(item["pixel_values"])
        masks.append(item["mask"])

    batch = {}
    batch["pixel_values"] = torch.stack(images)
    batch["bool_masked_pos"] = torch.stack(masks)
    return batch
382
383
def main():
    """Run masked-image-modeling pre-training with an Accelerate-driven loop.

    Steps, in order: parse the CLI arguments, create the `Accelerator`, optionally create and
    clone the Hub repository, load the dataset (splitting off a validation set when none exists),
    build the config / image processor / model, attach the on-the-fly transform that produces
    `pixel_values` and `mask`, build the dataloaders, optimizer and LR scheduler, optionally
    resume from a checkpoint, then alternate training and evaluation for every epoch and finally
    save (and optionally push) the model and image processor.

    Raises:
        ValueError: if both `--token` and the deprecated `--use_auth_token` are given.
    """
    args = parse_args()

    if args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        args.token = args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_mim_no_trainer", args)

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_log_kwargs = {}

    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            # The parser may not declare the Hub-specific options; read them defensively and fall
            # back to the general `--token` for authentication.
            hub_model_id = getattr(args, "hub_model_id", None)
            hub_token = getattr(args, "hub_token", None)
            if hub_token is None:
                hub_token = args.token

            # Retrieve or infer repo_name
            repo_name = hub_model_id
            if repo_name is None:
                repo_name = Path(args.output_dir).absolute().name
            # Create repo and retrieve repo_id
            repo_id = create_repo(repo_name, exist_ok=True, token=hub_token).repo_id
            # Clone repo locally
            repo = Repository(args.output_dir, clone_from=repo_id, token=hub_token)

            # Keep intermediate checkpoints out of the pushed repo. Read the cloned .gitignore
            # first and append only the missing patterns, so existing rules are preserved.
            gitignore_path = os.path.join(args.output_dir, ".gitignore")
            existing_rules = ""
            if os.path.exists(gitignore_path):
                with open(gitignore_path) as gitignore:
                    existing_rules = gitignore.read()
            with open(gitignore_path, "a") as gitignore:
                if existing_rules and not existing_rules.endswith("\n"):
                    gitignore.write("\n")
                for pattern in ("step_*", "epoch_*"):
                    if pattern not in existing_rules:
                        gitignore.write(f"{pattern}\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Initialize our dataset.
    ds = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        data_files=args.data_files,
        cache_dir=args.cache_dir,
        token=args.token,
    )

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in ds.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = ds["train"].train_test_split(args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Create config
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": args.cache_dir,
        "revision": args.model_revision,
        "token": args.token,
        "trust_remote_code": args.trust_remote_code,
    }
    if args.config_name_or_path:
        config = AutoConfig.from_pretrained(args.config_name_or_path, **config_kwargs)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if args.config_overrides is not None:
            logger.info(f"Overriding config: {args.config_overrides}")
            config.update_from_string(args.config_overrides)
            logger.info(f"New config: {config}")

    # make sure the decoder_type is "simmim" (only relevant for BEiT)
    if hasattr(config, "decoder_type"):
        config.decoder_type = "simmim"

    # adapt config: CLI values win, otherwise keep what the config already specifies
    args.image_size = args.image_size if args.image_size is not None else config.image_size
    args.patch_size = args.patch_size if args.patch_size is not None else config.patch_size
    args.encoder_stride = args.encoder_stride if args.encoder_stride is not None else config.encoder_stride

    config.update(
        {
            "image_size": args.image_size,
            "patch_size": args.patch_size,
            "encoder_stride": args.encoder_stride,
        }
    )

    # create image processor
    if args.image_processor_name:
        image_processor = AutoImageProcessor.from_pretrained(args.image_processor_name, **config_kwargs)
    elif args.model_name_or_path:
        image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path, **config_kwargs)
    else:
        IMAGE_PROCESSOR_TYPES = {
            conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
        }
        image_processor = IMAGE_PROCESSOR_TYPES[args.model_type]()

    # create model
    if args.model_name_or_path:
        model = AutoModelForMaskedImageModeling.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir,
            revision=args.model_revision,
            token=args.token,
            trust_remote_code=args.trust_remote_code,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedImageModeling.from_config(
            config,
            token=args.token,
            trust_remote_code=args.trust_remote_code,
        )

    column_names = ds["train"].column_names

    if args.image_column_name is not None:
        image_column_name = args.image_column_name
    elif "image" in column_names:
        image_column_name = "image"
    elif "img" in column_names:
        image_column_name = "img"
    else:
        image_column_name = column_names[0]

    # transformations as done in original SimMIM paper
    # source: https://github.com/microsoft/SimMIM/blob/main/data/data_simmim.py
    transforms = Compose(
        [
            Lambda(lambda img: img.convert("RGB")),
            RandomResizedCrop(args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        ]
    )

    # create mask generator
    mask_generator = MaskGenerator(
        input_size=args.image_size,
        mask_patch_size=args.mask_patch_size,
        model_patch_size=args.patch_size,
        mask_ratio=args.mask_ratio,
    )

    def preprocess_images(examples):
        """Preprocess a batch of images by applying transforms + creating a corresponding mask, indicating
        which patches to mask."""

        examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]]
        examples["mask"] = [mask_generator() for i in range(len(examples[image_column_name]))]

        return examples

    if args.max_train_samples is not None:
        ds["train"] = ds["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    # Set the training transforms
    ds["train"].set_transform(preprocess_images)

    if args.max_eval_samples is not None:
        ds["validation"] = ds["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
    # Set the validation transforms
    ds["validation"].set_transform(preprocess_images)

    # DataLoaders creation:
    train_dataloader = DataLoader(
        ds["train"],
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        ds["validation"],
        collate_fn=collate_fn,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
    # shorter in multiprocess)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps
        if overrode_max_train_steps
        else args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    )

    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("mim_no_trainer", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(ds['train'])}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(int(args.max_train_steps)), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        checkpoint_path = args.resume_from_checkpoint
        # `normpath` drops a trailing separator so that "out/step_100/" still yields "step_100".
        path = os.path.basename(os.path.normpath(checkpoint_path))

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
            starting_epoch = resume_step // len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
            resume_step -= starting_epoch * len(train_dataloader)

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader
        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            # Only checkpoint right after a real optimizer step; otherwise, with gradient
            # accumulation, the same `step_N` would be written again on every micro-batch.
            if isinstance(checkpointing_steps, int) and accelerator.sync_gradients:
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        eval_loss = torch.mean(losses)

        logger.info(f"epoch {epoch}: eval_loss: {eval_loss}")

        if args.with_tracking:
            accelerator.log(
                {
                    "eval_loss": eval_loss,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

        # Intermediate pushes happen after every epoch except the last one, which is covered by
        # the final save below.
        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                image_processor.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    if args.with_tracking:
        accelerator.end_training()

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            image_processor.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
807
808
# Run the training pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
811