import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
import threading
import warnings
from pathlib import Path
from typing import Optional, Union

import datasets
import diffusers
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

from peft import LoHaConfig, LoKrConfig, LoraConfig, get_peft_model


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__)

UNET_TARGET_MODULES = [
    "to_q",
    "to_k",
    "to_v",
    "proj",
    "proj_in",
    "proj_out",
    "conv",
    "conv1",
    "conv2",
    "conv_shortcut",
    "to_out.0",
    "time_emb_proj",
    "ff.net.2",
]

TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
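# PEFT matches these names against the module names inside the UNet / CLIP text
# encoder, so the adapters are injected into the attention projections, conv
# layers and feed-forward blocks listed above.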


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


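# LoHa and LoKr below come from the LyCORIS family of adapters; their `r` and
# `alpha` play roles analogous to LoRA's rank and scaling factor.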
def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig]:
    if args.adapter == "full":
        raise ValueError("Cannot create a unet adapter config for full-parameter training")

    if args.adapter == "lora":
        config = LoraConfig(
            r=args.unet_r,
            lora_alpha=args.unet_alpha,
            target_modules=UNET_TARGET_MODULES,
            lora_dropout=args.unet_dropout,
            bias=args.unet_bias,
            init_lora_weights=True,
        )
    elif args.adapter == "loha":
        config = LoHaConfig(
            r=args.unet_r,
            alpha=args.unet_alpha,
            target_modules=UNET_TARGET_MODULES,
            rank_dropout=args.unet_rank_dropout,
            module_dropout=args.unet_module_dropout,
            use_effective_conv2d=args.unet_use_effective_conv2d,
            init_weights=True,
        )
    elif args.adapter == "lokr":
        config = LoKrConfig(
            r=args.unet_r,
            alpha=args.unet_alpha,
            target_modules=UNET_TARGET_MODULES,
            rank_dropout=args.unet_rank_dropout,
            module_dropout=args.unet_module_dropout,
            use_effective_conv2d=args.unet_use_effective_conv2d,
            decompose_both=args.unet_decompose_both,
            decompose_factor=args.unet_decompose_factor,
            init_weights=True,
        )
    else:
        raise ValueError(f"Unknown adapter type {args.adapter}")

    return config


def create_text_encoder_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig]:
    if args.adapter == "full":
        raise ValueError("Cannot create a text_encoder adapter config for full-parameter training")

    if args.adapter == "lora":
        config = LoraConfig(
            r=args.te_r,
            lora_alpha=args.te_alpha,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            lora_dropout=args.te_dropout,
            bias=args.te_bias,
            init_lora_weights=True,
        )
    elif args.adapter == "loha":
        config = LoHaConfig(
            r=args.te_r,
            alpha=args.te_alpha,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            rank_dropout=args.te_rank_dropout,
            module_dropout=args.te_module_dropout,
            init_weights=True,
        )
    elif args.adapter == "lokr":
        config = LoKrConfig(
            r=args.te_r,
            alpha=args.te_alpha,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            rank_dropout=args.te_rank_dropout,
            module_dropout=args.te_module_dropout,
            decompose_both=args.te_decompose_both,
            decompose_factor=args.te_decompose_factor,
            init_weights=True,
        )
    else:
        raise ValueError(f"Unknown adapter type {args.adapter}")

    return config


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")

    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--wandb_key",
        type=str,
        default=None,
        help="API key used to log in to wandb when `--report_to` is set to `wandb`.",
    )
    parser.add_argument(
        "--wandb_project_name",
        type=str,
        default=None,
        help="wandb project name used for log tracking when `--report_to` is set to `wandb`.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    # Adapter arguments
    subparsers = parser.add_subparsers(dest="adapter")

    # Dummy subparser to train whole model
    subparsers.add_parser("full", help="Train full model without adapters")

    # LoRA adapter
    lora = subparsers.add_parser("lora", help="Use LoRA adapter")
    lora.add_argument("--unet_r", type=int, default=8, help="LoRA rank for unet")
    lora.add_argument("--unet_alpha", type=int, default=8, help="LoRA alpha for unet")
    lora.add_argument("--unet_dropout", type=float, default=0.0, help="LoRA dropout probability for unet")
    lora.add_argument(
        "--unet_bias",
        type=str,
        default="none",
        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only'",
    )
    lora.add_argument(
        "--te_r", type=int, default=8, help="LoRA rank for text_encoder, only used if `train_text_encoder` is True"
    )
    lora.add_argument(
        "--te_alpha",
        type=int,
        default=8,
        help="LoRA alpha for text_encoder, only used if `train_text_encoder` is True",
    )
    lora.add_argument(
        "--te_dropout",
        type=float,
        default=0.0,
        help="LoRA dropout probability for text_encoder, only used if `train_text_encoder` is True",
    )
    lora.add_argument(
        "--te_bias",
        type=str,
        default="none",
        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `train_text_encoder` is True",
    )

    # LoHa adapter
    loha = subparsers.add_parser("loha", help="Use LoHa adapter")
    loha.add_argument("--unet_r", type=int, default=8, help="LoHa rank for unet")
    loha.add_argument("--unet_alpha", type=int, default=8, help="LoHa alpha for unet")
    loha.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoHa rank_dropout probability for unet")
    loha.add_argument(
        "--unet_module_dropout", type=float, default=0.0, help="LoHa module_dropout probability for unet"
    )
    loha.add_argument(
        "--unet_use_effective_conv2d",
        action="store_true",
        help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1",
    )
    loha.add_argument(
        "--te_r", type=int, default=8, help="LoHa rank for text_encoder, only used if `train_text_encoder` is True"
    )
    loha.add_argument(
        "--te_alpha",
        type=int,
        default=8,
        help="LoHa alpha for text_encoder, only used if `train_text_encoder` is True",
    )
    loha.add_argument(
        "--te_rank_dropout",
        type=float,
        default=0.0,
        help="LoHa rank_dropout probability for text_encoder, only used if `train_text_encoder` is True",
    )
    loha.add_argument(
        "--te_module_dropout",
        type=float,
        default=0.0,
        help="LoHa module_dropout probability for text_encoder, only used if `train_text_encoder` is True",
    )

    # LoKr adapter
    lokr = subparsers.add_parser("lokr", help="Use LoKr adapter")
    lokr.add_argument("--unet_r", type=int, default=8, help="LoKr rank for unet")
    lokr.add_argument("--unet_alpha", type=int, default=8, help="LoKr alpha for unet")
    lokr.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoKr rank_dropout probability for unet")
    lokr.add_argument(
        "--unet_module_dropout", type=float, default=0.0, help="LoKr module_dropout probability for unet"
    )
    lokr.add_argument(
        "--unet_use_effective_conv2d",
        action="store_true",
        help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1",
    )
    lokr.add_argument(
        "--unet_decompose_both", action="store_true", help="Decompose left matrix in kronecker product for unet"
    )
    lokr.add_argument(
        "--unet_decompose_factor", type=int, default=-1, help="Decompose factor in kronecker product for unet"
    )
    lokr.add_argument(
        "--te_r", type=int, default=8, help="LoKr rank for text_encoder, only used if `train_text_encoder` is True"
    )
    lokr.add_argument(
        "--te_alpha",
        type=int,
        default=8,
        help="LoKr alpha for text_encoder, only used if `train_text_encoder` is True",
    )
    lokr.add_argument(
        "--te_rank_dropout",
        type=float,
        default=0.0,
        help="LoKr rank_dropout probability for text_encoder, only used if `train_text_encoder` is True",
    )
    lokr.add_argument(
        "--te_module_dropout",
        type=float,
        default=0.0,
        help="LoKr module_dropout probability for text_encoder, only used if `train_text_encoder` is True",
    )
    lokr.add_argument(
        "--te_decompose_both",
        action="store_true",
        help="Decompose left matrix in kronecker product for text_encoder, only used if `train_text_encoder` is True",
    )
    lokr.add_argument(
        "--te_decompose_factor",
        type=int,
        default=-1,
        help="Decompose factor in kronecker product for text_encoder, only used if `train_text_encoder` is True",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args


# Converting Bytes to Megabytes
def b2mb(x):
    return int(x / 2**20)
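# e.g. b2mb(5 * 2**20) == 5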


# This context manager is used to track the peak memory usage of the process
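# Usage sketch:
#     with TorchTracemalloc() as tm:
#         ...  # training code
#     print(tm.peaked, tm.cpu_peaked)  # GPU / CPU peak deltas in MB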
class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001)  # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example


def collate_fn(examples, with_prior_preservation=False):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
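    # The resulting batch is therefore ordered [instance..., class...], which is
    # what the torch.chunk(..., 2, dim=0) split in the training loop relies on.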
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


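# Note: HfFolder and Repository are legacy huggingface_hub APIs; newer versions
# of the hub library favor HfApi / create_repo / upload_folder instead.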
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )
    if args.report_to == "wandb":
        import wandb

        wandb.login(key=args.wandb_key)
        wandb.init(project=args.wandb_project_name)

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
                torch_dtype = torch.float16
            elif args.prior_generation_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)  # noqa: F841

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )  # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    if args.adapter != "full":
        config = create_unet_adapter_config(args)
        unet = get_peft_model(unet, config)
        unet.print_trainable_parameters()
        print(unet)

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)
    elif args.train_text_encoder and args.adapter != "full":
        config = create_text_encoder_adapter_config(args)
        text_encoder = get_peft_model(text_encoder, config)
        text_encoder.print_trainable_parameters()
        print(text_encoder)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder and args.adapter == "full":
            text_encoder.gradient_checkpointing_enable()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )
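        # e.g. with the defaults (lr=5e-6, accumulation=1, batch=4) on 2 processes this yields 4e-5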

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
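    # Note: with an adapter, unet.parameters() still contains the frozen base weights;
    # only the adapter parameters get gradients, and AdamW skips params whose grad is None.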

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
        num_workers=1,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = resume_global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % num_update_steps_per_epoch
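        # e.g. global_step = 500 with gradient_accumulation_steps = 2 gives resume_global_step = 1000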

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        with TorchTracemalloc() as tracemalloc:
            for step, batch in enumerate(train_dataloader):
                # Skip steps until we reach the resumed step
                if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                    if step % args.gradient_accumulation_steps == 0:
                        progress_bar.update(1)
                        if args.report_to == "wandb":
                            accelerator.print(progress_bar)
                    continue

                with accelerator.accumulate(unet):
                    # Convert images to latent space
                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215
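                    # 0.18215 is the SD VAE latent scaling factor (vae.config.scaling_factor in newer diffusers)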

                    # Sample noise that we'll add to the latents
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                    )
                    timesteps = timesteps.long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get the text embedding for conditioning
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    # Predict the noise residual
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    # Get the target for loss depending on the prediction type
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    if args.with_prior_preservation:
                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                        target, target_prior = torch.chunk(target, 2, dim=0)

                        # Compute instance loss
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                        # Compute prior loss
                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                        # Add the prior loss to the instance loss.
                        loss = loss + args.prior_loss_weight * prior_loss
                    else:
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    accelerator.backward(loss)
                    if accelerator.sync_gradients:
                        params_to_clip = (
                            itertools.chain(unet.parameters(), text_encoder.parameters())
                            if args.train_text_encoder
                            else unet.parameters()
                        )
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    if args.report_to == "wandb":
                        accelerator.print(progress_bar)
                    global_step += 1

                    # if global_step % args.checkpointing_steps == 0:
                    #     if accelerator.is_main_process:
                    #         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    #         accelerator.save_state(save_path)
                    #         logger.info(f"Saved state to {save_path}")

                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                if (
                    args.validation_prompt is not None
                    and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0
                ):
                    logger.info(
                        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                        f" {args.validation_prompt}."
                    )
                    # create pipeline
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        safety_checker=None,
                        revision=args.revision,
                    )
                    # set `keep_fp32_wrapper` to True because we do not want to remove
                    # mixed precision hooks while we are still training
                    pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
                    pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                    pipeline = pipeline.to(accelerator.device)
                    pipeline.set_progress_bar_config(disable=True)

                    # Set evaluation mode
                    pipeline.unet.eval()
                    pipeline.text_encoder.eval()

                    # run inference
                    if args.seed is not None:
                        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                    else:
                        generator = None
                    images = []
                    for _ in range(args.num_validation_images):
                        image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                        images.append(image)

                    for tracker in accelerator.trackers:
                        if tracker.name == "tensorboard":
                            np_images = np.stack([np.asarray(img) for img in images])
                            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                        if tracker.name == "wandb":
                            import wandb

                            tracker.log(
                                {
                                    "validation": [
                                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                        for i, image in enumerate(images)
                                    ]
                                }
                            )

                    # Restore training mode
                    pipeline.unet.train()
                    pipeline.text_encoder.train()

                    del pipeline
                    torch.cuda.empty_cache()

                if global_step >= args.max_train_steps:
                    break
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
        accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
        accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )

        accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
        accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}")
        accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}")
        accelerator.print(
            "CPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
            )
        )

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.adapter != "full":
            unwrapped_unet = accelerator.unwrap_model(unet)
            unwrapped_unet.save_pretrained(
                os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet)
            )
            if args.train_text_encoder:
                unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
                unwrapped_text_encoder.save_pretrained(
                    os.path.join(args.output_dir, "text_encoder"),
                    state_dict=accelerator.get_state_dict(text_encoder),
                )
        else:
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
                revision=args.revision,
            )
            pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)