import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
import threading
import warnings
from pathlib import Path
from typing import Optional, Union

import datasets
import diffusers
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

from peft import LoHaConfig, LoKrConfig, LoraConfig, get_peft_model


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
46

47
logger = get_logger(__name__)
48

49
UNET_TARGET_MODULES = [
50
    "to_q",
51
    "to_k",
52
    "to_v",
53
    "proj",
54
    "proj_in",
55
    "proj_out",
56
    "conv",
57
    "conv1",
58
    "conv2",
59
    "conv_shortcut",
60
    "to_out.0",
61
    "time_emb_proj",
62
    "ff.net.2",
63
]
64

65
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
66

67

68
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
69
    text_encoder_config = PretrainedConfig.from_pretrained(
70
        pretrained_model_name_or_path,
71
        subfolder="text_encoder",
72
        revision=revision,
73
    )
74
    model_class = text_encoder_config.architectures[0]
75

76
    if model_class == "CLIPTextModel":
77
        from transformers import CLIPTextModel
78

79
        return CLIPTextModel
80
    elif model_class == "RobertaSeriesModelWithTransformation":
81
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
82

83
        return RobertaSeriesModelWithTransformation
84
    else:
85
        raise ValueError(f"{model_class} is not supported.")
86

87

88
def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig]:
89
    if args.adapter == "full":
90
        raise ValueError("Cannot create unet adapter config for full parameter")
91

92
    if args.adapter == "lora":
93
        config = LoraConfig(
94
            r=args.unet_r,
95
            lora_alpha=args.unet_alpha,
96
            target_modules=UNET_TARGET_MODULES,
97
            lora_dropout=args.unet_dropout,
98
            bias=args.unet_bias,
99
            init_lora_weights=True,
100
        )
101
    elif args.adapter == "loha":
102
        config = LoHaConfig(
103
            r=args.unet_r,
104
            alpha=args.unet_alpha,
105
            target_modules=UNET_TARGET_MODULES,
106
            rank_dropout=args.unet_rank_dropout,
107
            module_dropout=args.unet_module_dropout,
108
            use_effective_conv2d=args.unet_use_effective_conv2d,
109
            init_weights=True,
110
        )
111
    elif args.adapter == "lokr":
112
        config = LoKrConfig(
113
            r=args.unet_r,
114
            alpha=args.unet_alpha,
115
            target_modules=UNET_TARGET_MODULES,
116
            rank_dropout=args.unet_rank_dropout,
117
            module_dropout=args.unet_module_dropout,
118
            use_effective_conv2d=args.unet_use_effective_conv2d,
119
            decompose_both=args.unet_decompose_both,
120
            decompose_factor=args.unet_decompose_factor,
121
            init_weights=True,
122
        )
123
    else:
124
        raise ValueError(f"Unknown adapter type {args.adapter}")
125

126
    return config
127

128

129
def create_text_encoder_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig]:
130
    if args.adapter == "full":
131
        raise ValueError("Cannot create text_encoder adapter config for full parameter")
132

133
    if args.adapter == "lora":
134
        config = LoraConfig(
135
            r=args.te_r,
136
            lora_alpha=args.te_alpha,
137
            target_modules=TEXT_ENCODER_TARGET_MODULES,
138
            lora_dropout=args.te_dropout,
139
            bias=args.te_bias,
140
            init_lora_weights=True,
141
        )
142
    elif args.adapter == "loha":
143
        config = LoHaConfig(
144
            r=args.te_r,
145
            alpha=args.te_alpha,
146
            target_modules=TEXT_ENCODER_TARGET_MODULES,
147
            rank_dropout=args.te_rank_dropout,
148
            module_dropout=args.te_module_dropout,
149
            init_weights=True,
150
        )
151
    elif args.adapter == "lokr":
152
        config = LoKrConfig(
153
            r=args.te_r,
154
            alpha=args.te_alpha,
155
            target_modules=TEXT_ENCODER_TARGET_MODULES,
156
            rank_dropout=args.te_rank_dropout,
157
            module_dropout=args.te_module_dropout,
158
            decompose_both=args.te_decompose_both,
159
            decompose_factor=args.te_decompose_factor,
160
            init_weights=True,
161
        )
162
    else:
163
        raise ValueError(f"Unknown adapter type {args.adapter}")
164

165
    return config
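
# Illustrative sketch (not executed here): the configs built above are applied in `main` via
# `get_peft_model`. For the `lora` subcommand with its default values this is roughly equivalent to:
#
#   config = LoraConfig(r=8, lora_alpha=8, target_modules=UNET_TARGET_MODULES, lora_dropout=0.0, bias="none")
#   unet = get_peft_model(unet, config)
#   unet.print_trainable_parameters()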
166

167

168
def parse_args(input_args=None):
169
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
170
    parser.add_argument(
171
        "--pretrained_model_name_or_path",
172
        type=str,
173
        default=None,
174
        required=True,
175
        help="Path to pretrained model or model identifier from huggingface.co/models.",
176
    )
177
    parser.add_argument(
178
        "--revision",
179
        type=str,
180
        default=None,
181
        required=False,
182
        help="Revision of pretrained model identifier from huggingface.co/models.",
183
    )
184
    parser.add_argument(
185
        "--tokenizer_name",
186
        type=str,
187
        default=None,
188
        help="Pretrained tokenizer name or path if not the same as model_name",
189
    )
190
    parser.add_argument(
191
        "--instance_data_dir",
192
        type=str,
193
        default=None,
194
        required=True,
195
        help="A folder containing the training data of instance images.",
196
    )
197
    parser.add_argument(
198
        "--class_data_dir",
199
        type=str,
200
        default=None,
201
        required=False,
202
        help="A folder containing the training data of class images.",
203
    )
204
    parser.add_argument(
205
        "--instance_prompt",
206
        type=str,
207
        default=None,
208
        required=True,
209
        help="The prompt with identifier specifying the instance",
210
    )
211
    parser.add_argument(
212
        "--class_prompt",
213
        type=str,
214
        default=None,
215
        help="The prompt to specify images in the same class as provided instance images.",
216
    )
217
    parser.add_argument(
218
        "--with_prior_preservation",
219
        default=False,
220
        action="store_true",
221
        help="Flag to add prior preservation loss.",
222
    )
223
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
224
    parser.add_argument(
225
        "--num_class_images",
226
        type=int,
227
        default=100,
228
        help=(
229
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
230
            " class_data_dir, additional images will be sampled with class_prompt."
231
        ),
232
    )
233
    parser.add_argument(
234
        "--validation_prompt",
235
        type=str,
236
        default=None,
237
        help="A prompt that is used during validation to verify that the model is learning.",
238
    )
239
    parser.add_argument(
240
        "--num_validation_images",
241
        type=int,
242
        default=4,
243
        help="Number of images that should be generated during validation with `validation_prompt`.",
244
    )
245
    parser.add_argument(
246
        "--validation_steps",
247
        type=int,
248
        default=100,
249
        help=(
            "Run DreamBooth validation every X steps. DreamBooth validation consists of generating"
            " `args.num_validation_images` images with the prompt `args.validation_prompt`."
252
        ),
253
    )
254
    parser.add_argument(
255
        "--output_dir",
256
        type=str,
257
        default="text-inversion-model",
258
        help="The output directory where the model predictions and checkpoints will be written.",
259
    )
260
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
261
    parser.add_argument(
262
        "--resolution",
263
        type=int,
264
        default=512,
265
        help=(
266
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
267
            " resolution"
268
        ),
269
    )
270
    parser.add_argument(
271
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
272
    )
273
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
274

275
    parser.add_argument(
276
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
277
    )
278
    parser.add_argument(
279
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
280
    )
281
    parser.add_argument("--num_train_epochs", type=int, default=1)
282
    parser.add_argument(
283
        "--max_train_steps",
284
        type=int,
285
        default=None,
286
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
287
    )
288
    parser.add_argument(
289
        "--checkpointing_steps",
290
        type=int,
291
        default=500,
292
        help=(
293
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
294
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
295
            " training using `--resume_from_checkpoint`."
296
        ),
297
    )
298
    parser.add_argument(
299
        "--resume_from_checkpoint",
300
        type=str,
301
        default=None,
302
        help=(
303
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
304
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
305
        ),
306
    )
307
    parser.add_argument(
308
        "--gradient_accumulation_steps",
309
        type=int,
310
        default=1,
311
        help="Number of update steps to accumulate before performing a backward/update pass.",
312
    )
313
    parser.add_argument(
314
        "--gradient_checkpointing",
315
        action="store_true",
316
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
317
    )
318
    parser.add_argument(
319
        "--learning_rate",
320
        type=float,
321
        default=5e-6,
322
        help="Initial learning rate (after the potential warmup period) to use.",
323
    )
324
    parser.add_argument(
325
        "--scale_lr",
326
        action="store_true",
327
        default=False,
328
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
329
    )
330
    parser.add_argument(
331
        "--lr_scheduler",
332
        type=str,
333
        default="constant",
334
        help=(
335
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
336
            ' "constant", "constant_with_warmup"]'
337
        ),
338
    )
339
    parser.add_argument(
340
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
341
    )
342
    parser.add_argument(
343
        "--lr_num_cycles",
344
        type=int,
345
        default=1,
346
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
347
    )
348
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
349
    parser.add_argument(
350
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
351
    )
352
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
353
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
354
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
355
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
356
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
357
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
358
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
359
    parser.add_argument(
360
        "--hub_model_id",
361
        type=str,
362
        default=None,
363
        help="The name of the repository to keep in sync with the local `output_dir`.",
364
    )
365
    parser.add_argument(
366
        "--logging_dir",
367
        type=str,
368
        default="logs",
369
        help=(
370
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
371
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
372
        ),
373
    )
374
    parser.add_argument(
375
        "--allow_tf32",
376
        action="store_true",
377
        help=(
378
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
379
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
380
        ),
381
    )
382
    parser.add_argument(
383
        "--report_to",
384
        type=str,
385
        default="tensorboard",
386
        help=(
387
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
388
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
389
        ),
390
    )
391
    parser.add_argument(
392
        "--wandb_key",
393
        type=str,
394
        default=None,
395
        help="The wandb API key used to log in when `--report_to` is set to `wandb`.",
396
    )
397
    parser.add_argument(
398
        "--wandb_project_name",
399
        type=str,
400
        default=None,
401
        help="The wandb project name used for log tracking when `--report_to` is set to `wandb`.",
402
    )
403
    parser.add_argument(
404
        "--mixed_precision",
405
        type=str,
406
        default=None,
407
        choices=["no", "fp16", "bf16"],
408
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
412
        ),
413
    )
414
    parser.add_argument(
415
        "--prior_generation_precision",
416
        type=str,
417
        default=None,
418
        choices=["no", "fp32", "fp16", "bf16"],
419
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
422
        ),
423
    )
424
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
425
    parser.add_argument(
426
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
427
    )
428

429
    # Adapter arguments
430
    subparsers = parser.add_subparsers(dest="adapter")
431

432
    # Dummy subparser to train whole model
433
    subparsers.add_parser("full", help="Train full model without adapters")
434

435
    # LoRA adapter
436
    lora = subparsers.add_parser("lora", help="Use LoRA adapter")
437
    lora.add_argument("--unet_r", type=int, default=8, help="LoRA rank for unet")
438
    lora.add_argument("--unet_alpha", type=int, default=8, help="LoRA alpha for unet")
439
    lora.add_argument("--unet_dropout", type=float, default=0.0, help="LoRA dropout probability for unet")
440
    lora.add_argument(
441
        "--unet_bias",
442
        type=str,
443
        default="none",
444
        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only'",
445
    )
446
    lora.add_argument(
447
        "--te_r", type=int, default=8, help="LoRA rank for text_encoder, only used if `train_text_encoder` is True"
448
    )
449
    lora.add_argument(
450
        "--te_alpha",
451
        type=int,
452
        default=8,
453
        help="LoRA alpha for text_encoder, only used if `train_text_encoder` is True",
454
    )
455
    lora.add_argument(
456
        "--te_dropout",
457
        type=float,
458
        default=0.0,
459
        help="LoRA dropout probability for text_encoder, only used if `train_text_encoder` is True",
460
    )
461
    lora.add_argument(
462
        "--te_bias",
463
        type=str,
464
        default="none",
465
        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `train_text_encoder` is True",
466
    )
467

468
    # LoHa adapter
469
    loha = subparsers.add_parser("loha", help="Use LoHa adapter")
470
    loha.add_argument("--unet_r", type=int, default=8, help="LoHa rank for unet")
471
    loha.add_argument("--unet_alpha", type=int, default=8, help="LoHa alpha for unet")
472
    loha.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoHa rank_dropout probability for unet")
473
    loha.add_argument(
474
        "--unet_module_dropout", type=float, default=0.0, help="LoHa module_dropout probability for unet"
475
    )
476
    loha.add_argument(
477
        "--unet_use_effective_conv2d",
478
        action="store_true",
479
        help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1",
480
    )
481
    loha.add_argument(
482
        "--te_r", type=int, default=8, help="LoHa rank for text_encoder, only used if `train_text_encoder` is True"
483
    )
484
    loha.add_argument(
485
        "--te_alpha",
486
        type=int,
487
        default=8,
488
        help="LoHa alpha for text_encoder, only used if `train_text_encoder` is True",
489
    )
490
    loha.add_argument(
491
        "--te_rank_dropout",
492
        type=float,
493
        default=0.0,
494
        help="LoHa rank_dropout probability for text_encoder, only used if `train_text_encoder` is True",
495
    )
496
    loha.add_argument(
497
        "--te_module_dropout",
498
        type=float,
499
        default=0.0,
500
        help="LoHa module_dropout probability for text_encoder, only used if `train_text_encoder` is True",
501
    )
502

503
    # LoKr adapter
504
    lokr = subparsers.add_parser("lokr", help="Use LoKr adapter")
505
    lokr.add_argument("--unet_r", type=int, default=8, help="LoKr rank for unet")
506
    lokr.add_argument("--unet_alpha", type=int, default=8, help="LoKr alpha for unet")
507
    lokr.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoKr rank_dropout probability for unet")
508
    lokr.add_argument(
509
        "--unet_module_dropout", type=float, default=0.0, help="LoKr module_dropout probability for unet"
510
    )
511
    lokr.add_argument(
512
        "--unet_use_effective_conv2d",
513
        action="store_true",
514
        help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1",
515
    )
516
    lokr.add_argument(
517
        "--unet_decompose_both", action="store_true", help="Decompose left matrix in kronecker product for unet"
518
    )
519
    lokr.add_argument(
520
        "--unet_decompose_factor", type=int, default=-1, help="Decompose factor in kronecker product for unet"
521
    )
522
    lokr.add_argument(
523
        "--te_r", type=int, default=8, help="LoKr rank for text_encoder, only used if `train_text_encoder` is True"
524
    )
525
    lokr.add_argument(
526
        "--te_alpha",
527
        type=int,
528
        default=8,
529
        help="LoKr alpha for text_encoder, only used if `train_text_encoder` is True",
530
    )
531
    lokr.add_argument(
532
        "--te_rank_dropout",
533
        type=float,
534
        default=0.0,
535
        help="LoKr rank_dropout probability for text_encoder, only used if `train_text_encoder` is True",
536
    )
537
    lokr.add_argument(
538
        "--te_module_dropout",
539
        type=float,
540
        default=0.0,
541
        help="LoKr module_dropout probability for text_encoder, only used if `train_text_encoder` is True",
542
    )
543
    lokr.add_argument(
544
        "--te_decompose_both",
545
        action="store_true",
546
        help="Decompose left matrix in kronecker product for text_encoder, only used if `train_text_encoder` is True",
547
    )
548
    lokr.add_argument(
549
        "--te_decompose_factor",
550
        type=int,
551
        default=-1,
552
        help="Decompose factor in kronecker product for text_encoder, only used if `train_text_encoder` is True",
553
    )
554

555
    if input_args is not None:
556
        args = parser.parse_args(input_args)
557
    else:
558
        args = parser.parse_args()
559

560
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
561
    if env_local_rank != -1 and env_local_rank != args.local_rank:
562
        args.local_rank = env_local_rank
563

564
    if args.with_prior_preservation:
565
        if args.class_data_dir is None:
566
            raise ValueError("You must specify a data directory for class images.")
567
        if args.class_prompt is None:
568
            raise ValueError("You must specify prompt for class images.")
569
    else:
570
        # logger is not available yet
571
        if args.class_data_dir is not None:
572
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
573
        if args.class_prompt is not None:
574
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
575

576
    return args
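
# Illustrative invocation sketch for this parser (the model id, paths, and hyperparameter values
# below are placeholders/assumptions, not recommendations). The script is typically launched with
# `accelerate launch`, and the adapter subcommand (`full`, `lora`, `loha`, or `lokr`) together with
# its own options comes after the main arguments:
#
#   accelerate launch train_dreambooth.py \
#       --pretrained_model_name_or_path=<base-model-id> \
#       --instance_data_dir=<path-to-instance-images> \
#       --instance_prompt="a photo of sks dog" \
#       --output_dir=<output-dir> \
#       --resolution=512 --train_batch_size=1 --max_train_steps=800 \
#       lora --unet_r 8 --unet_alpha 8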
577

578

579
# Converting Bytes to Megabytes
580
def b2mb(x):
581
    return int(x / 2**20)
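# Quick sanity example (illustrative): b2mb(3 * 2**20) == 3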
582

583

584
# This context manager is used to track the peak memory usage of the process
585
class TorchTracemalloc:
586
    def __enter__(self):
587
        gc.collect()
588
        torch.cuda.empty_cache()
589
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
590
        self.begin = torch.cuda.memory_allocated()
591
        self.process = psutil.Process()
592

593
        self.cpu_begin = self.cpu_mem_used()
594
        self.peak_monitoring = True
595
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
596
        peak_monitor_thread.daemon = True
597
        peak_monitor_thread.start()
598
        return self
599

600
    def cpu_mem_used(self):
601
        """get resident set size memory for the current process"""
602
        return self.process.memory_info().rss
603

604
    def peak_monitor_func(self):
605
        self.cpu_peak = -1
606

607
        while True:
608
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
609

610
            # can't sleep or will not catch the peak right (this comment is here on purpose)
611
            # time.sleep(0.001) # 1msec
612

613
            if not self.peak_monitoring:
614
                break
615

616
    def __exit__(self, *exc):
617
        self.peak_monitoring = False
618

619
        gc.collect()
620
        torch.cuda.empty_cache()
621
        self.end = torch.cuda.memory_allocated()
622
        self.peak = torch.cuda.max_memory_allocated()
623
        self.used = b2mb(self.end - self.begin)
624
        self.peaked = b2mb(self.peak - self.begin)
625

626
        self.cpu_end = self.cpu_mem_used()
627
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
628
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
629
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
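
# Usage sketch (illustrative; `run_one_epoch` is a hypothetical stand-in for the training loop below):
#
#   with TorchTracemalloc() as tracemalloc:
#       run_one_epoch()
#   print(tracemalloc.used, tracemalloc.peaked)          # GPU deltas in MiB
#   print(tracemalloc.cpu_used, tracemalloc.cpu_peaked)  # CPU deltas in MiB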
630

631

632
class DreamBoothDataset(Dataset):
633
    """
634
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
635
    It pre-processes the images and tokenizes the prompts.
636
    """
637

638
    def __init__(
639
        self,
640
        instance_data_root,
641
        instance_prompt,
642
        tokenizer,
643
        class_data_root=None,
644
        class_prompt=None,
645
        size=512,
646
        center_crop=False,
647
    ):
648
        self.size = size
649
        self.center_crop = center_crop
650
        self.tokenizer = tokenizer
651

652
        self.instance_data_root = Path(instance_data_root)
653
        if not self.instance_data_root.exists():
654
            raise ValueError("Instance images root doesn't exist.")
655

656
        self.instance_images_path = list(Path(instance_data_root).iterdir())
657
        self.num_instance_images = len(self.instance_images_path)
658
        self.instance_prompt = instance_prompt
659
        self._length = self.num_instance_images
660

661
        if class_data_root is not None:
662
            self.class_data_root = Path(class_data_root)
663
            self.class_data_root.mkdir(parents=True, exist_ok=True)
664
            self.class_images_path = list(self.class_data_root.iterdir())
665
            self.num_class_images = len(self.class_images_path)
666
            self._length = max(self.num_class_images, self.num_instance_images)
667
            self.class_prompt = class_prompt
668
        else:
669
            self.class_data_root = None
670

671
        self.image_transforms = transforms.Compose(
672
            [
673
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
674
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
675
                transforms.ToTensor(),
676
                transforms.Normalize([0.5], [0.5]),
677
            ]
678
        )
679

680
    def __len__(self):
681
        return self._length
682

683
    def __getitem__(self, index):
684
        example = {}
685
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
686
        if not instance_image.mode == "RGB":
687
            instance_image = instance_image.convert("RGB")
688
        example["instance_images"] = self.image_transforms(instance_image)
689
        example["instance_prompt_ids"] = self.tokenizer(
690
            self.instance_prompt,
691
            truncation=True,
692
            padding="max_length",
693
            max_length=self.tokenizer.model_max_length,
694
            return_tensors="pt",
695
        ).input_ids
696

697
        if self.class_data_root:
698
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
699
            if not class_image.mode == "RGB":
700
                class_image = class_image.convert("RGB")
701
            example["class_images"] = self.image_transforms(class_image)
702
            example["class_prompt_ids"] = self.tokenizer(
703
                self.class_prompt,
704
                truncation=True,
705
                padding="max_length",
706
                max_length=self.tokenizer.model_max_length,
707
                return_tensors="pt",
708
            ).input_ids
709

710
        return example
711

712

713
def collate_fn(examples, with_prior_preservation=False):
714
    input_ids = [example["instance_prompt_ids"] for example in examples]
715
    pixel_values = [example["instance_images"] for example in examples]
716

717
    # Concat class and instance examples for prior preservation.
718
    # We do this to avoid doing two forward passes.
719
    if with_prior_preservation:
720
        input_ids += [example["class_prompt_ids"] for example in examples]
721
        pixel_values += [example["class_images"] for example in examples]
722

723
    pixel_values = torch.stack(pixel_values)
724
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
725

726
    input_ids = torch.cat(input_ids, dim=0)
727

728
    batch = {
729
        "input_ids": input_ids,
730
        "pixel_values": pixel_values,
731
    }
732
    return batch
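
# Note: when `with_prior_preservation` is enabled, each batch contains the instance examples followed
# by the class examples along dim 0, which is why the training loop below splits the model prediction
# and the target with `torch.chunk(..., 2, dim=0)`.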
733

734

735
class PromptDataset(Dataset):
736
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
737

738
    def __init__(self, prompt, num_samples):
739
        self.prompt = prompt
740
        self.num_samples = num_samples
741

742
    def __len__(self):
743
        return self.num_samples
744

745
    def __getitem__(self, index):
746
        example = {}
747
        example["prompt"] = self.prompt
748
        example["index"] = index
749
        return example
750

751

752
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
753
    if token is None:
754
        token = HfFolder.get_token()
755
    if organization is None:
756
        username = whoami(token)["name"]
757
        return f"{username}/{model_id}"
758
    else:
759
        return f"{organization}/{model_id}"
760

761

762
def main(args):
763
    logging_dir = Path(args.output_dir, args.logging_dir)
764

765
    accelerator = Accelerator(
766
        gradient_accumulation_steps=args.gradient_accumulation_steps,
767
        mixed_precision=args.mixed_precision,
768
        log_with=args.report_to,
769
        project_dir=logging_dir,
770
    )
771
    if args.report_to == "wandb":
772
        import wandb
773

774
        wandb.login(key=args.wandb_key)
775
        wandb.init(project=args.wandb_project_name)
776
    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
777
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
778
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
779
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
780
        raise ValueError(
781
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
782
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
783
        )
784

785
    # Make one log on every process with the configuration for debugging.
786
    logging.basicConfig(
787
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
788
        datefmt="%m/%d/%Y %H:%M:%S",
789
        level=logging.INFO,
790
    )
791
    logger.info(accelerator.state, main_process_only=False)
792
    if accelerator.is_local_main_process:
793
        datasets.utils.logging.set_verbosity_warning()
794
        transformers.utils.logging.set_verbosity_warning()
795
        diffusers.utils.logging.set_verbosity_info()
796
    else:
797
        datasets.utils.logging.set_verbosity_error()
798
        transformers.utils.logging.set_verbosity_error()
799
        diffusers.utils.logging.set_verbosity_error()
800

801
    # If passed along, set the training seed now.
802
    if args.seed is not None:
803
        set_seed(args.seed)
804

805
    # Generate class images if prior preservation is enabled.
806
    if args.with_prior_preservation:
807
        class_images_dir = Path(args.class_data_dir)
808
        if not class_images_dir.exists():
809
            class_images_dir.mkdir(parents=True)
810
        cur_class_images = len(list(class_images_dir.iterdir()))
811

812
        if cur_class_images < args.num_class_images:
813
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
814
            if args.prior_generation_precision == "fp32":
815
                torch_dtype = torch.float32
816
            elif args.prior_generation_precision == "fp16":
817
                torch_dtype = torch.float16
818
            elif args.prior_generation_precision == "bf16":
819
                torch_dtype = torch.bfloat16
820
            pipeline = DiffusionPipeline.from_pretrained(
821
                args.pretrained_model_name_or_path,
822
                torch_dtype=torch_dtype,
823
                safety_checker=None,
824
                revision=args.revision,
825
            )
826
            pipeline.set_progress_bar_config(disable=True)
827

828
            num_new_images = args.num_class_images - cur_class_images
829
            logger.info(f"Number of class images to sample: {num_new_images}.")
830

831
            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
832
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
833

834
            sample_dataloader = accelerator.prepare(sample_dataloader)
835
            pipeline.to(accelerator.device)
836

837
            for example in tqdm(
838
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
839
            ):
840
                images = pipeline(example["prompt"]).images
841

842
                for i, image in enumerate(images):
843
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
844
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
845
                    image.save(image_filename)
846

847
            del pipeline
848
            if torch.cuda.is_available():
849
                torch.cuda.empty_cache()
850

851
    # Handle the repository creation
852
    if accelerator.is_main_process:
853
        if args.push_to_hub:
854
            if args.hub_model_id is None:
855
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
856
            else:
857
                repo_name = args.hub_model_id
858
            repo = Repository(args.output_dir, clone_from=repo_name)  # noqa: F841
859

860
            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
861
                if "step_*" not in gitignore:
862
                    gitignore.write("step_*\n")
863
                if "epoch_*" not in gitignore:
864
                    gitignore.write("epoch_*\n")
865
        elif args.output_dir is not None:
866
            os.makedirs(args.output_dir, exist_ok=True)
867

868
    # Load the tokenizer
869
    if args.tokenizer_name:
870
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
871
    elif args.pretrained_model_name_or_path:
872
        tokenizer = AutoTokenizer.from_pretrained(
873
            args.pretrained_model_name_or_path,
874
            subfolder="tokenizer",
875
            revision=args.revision,
876
            use_fast=False,
877
        )
878

879
    # import correct text encoder class
880
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
881

882
    # Load scheduler and models
883
    noise_scheduler = DDPMScheduler(
884
        beta_start=0.00085,
885
        beta_end=0.012,
886
        beta_schedule="scaled_linear",
887
        num_train_timesteps=1000,
888
    )  # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
889
    text_encoder = text_encoder_cls.from_pretrained(
890
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
891
    )
892
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
893
    unet = UNet2DConditionModel.from_pretrained(
894
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
895
    )
896

897
    if args.adapter != "full":
898
        config = create_unet_adapter_config(args)
899
        unet = get_peft_model(unet, config)
900
        unet.print_trainable_parameters()
901
        print(unet)
902

903
    vae.requires_grad_(False)
904
    if not args.train_text_encoder:
905
        text_encoder.requires_grad_(False)
906
    elif args.train_text_encoder and args.adapter != "full":
907
        config = create_text_encoder_adapter_config(args)
908
        text_encoder = get_peft_model(text_encoder, config)
909
        text_encoder.print_trainable_parameters()
910
        print(text_encoder)
911

912
    if args.enable_xformers_memory_efficient_attention:
913
        if is_xformers_available():
914
            unet.enable_xformers_memory_efficient_attention()
915
        else:
916
            raise ValueError("xformers is not available. Make sure it is installed correctly")
917

918
    if args.gradient_checkpointing:
919
        unet.enable_gradient_checkpointing()
920
        if args.train_text_encoder and args.adapter == "full":
921
            text_encoder.gradient_checkpointing_enable()
922

923
    # Enable TF32 for faster training on Ampere GPUs,
924
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
925
    if args.allow_tf32:
926
        torch.backends.cuda.matmul.allow_tf32 = True
927

928
    if args.scale_lr:
929
        args.learning_rate = (
930
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
931
        )
932

933
    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
934
    if args.use_8bit_adam:
935
        try:
936
            import bitsandbytes as bnb
937
        except ImportError:
938
            raise ImportError(
939
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
940
            )
941

942
        optimizer_class = bnb.optim.AdamW8bit
943
    else:
944
        optimizer_class = torch.optim.AdamW
945

946
    # Optimizer creation
947
    params_to_optimize = (
948
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
949
    )
950
    optimizer = optimizer_class(
951
        params_to_optimize,
952
        lr=args.learning_rate,
953
        betas=(args.adam_beta1, args.adam_beta2),
954
        weight_decay=args.adam_weight_decay,
955
        eps=args.adam_epsilon,
956
    )
957

958
    # Dataset and DataLoaders creation:
959
    train_dataset = DreamBoothDataset(
960
        instance_data_root=args.instance_data_dir,
961
        instance_prompt=args.instance_prompt,
962
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
963
        class_prompt=args.class_prompt,
964
        tokenizer=tokenizer,
965
        size=args.resolution,
966
        center_crop=args.center_crop,
967
    )
968

969
    train_dataloader = torch.utils.data.DataLoader(
970
        train_dataset,
971
        batch_size=args.train_batch_size,
972
        shuffle=True,
973
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
974
        num_workers=1,
975
    )
976

977
    # Scheduler and math around the number of training steps.
978
    overrode_max_train_steps = False
979
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
980
    if args.max_train_steps is None:
981
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
982
        overrode_max_train_steps = True
983

984
    lr_scheduler = get_scheduler(
985
        args.lr_scheduler,
986
        optimizer=optimizer,
987
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
988
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
989
        num_cycles=args.lr_num_cycles,
990
        power=args.lr_power,
991
    )
992

993
    # Prepare everything with our `accelerator`.
994
    if args.train_text_encoder:
995
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
996
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
997
        )
998
    else:
999
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1000
            unet, optimizer, train_dataloader, lr_scheduler
1001
        )
1002

1003
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
1004
    # as these models are only used for inference, keeping weights in full precision is not required.
1005
    weight_dtype = torch.float32
1006
    if accelerator.mixed_precision == "fp16":
1007
        weight_dtype = torch.float16
1008
    elif accelerator.mixed_precision == "bf16":
1009
        weight_dtype = torch.bfloat16
1010

1011
    # Move vae and text_encoder to device and cast to weight_dtype
1012
    vae.to(accelerator.device, dtype=weight_dtype)
1013
    if not args.train_text_encoder:
1014
        text_encoder.to(accelerator.device, dtype=weight_dtype)
1015

1016
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1017
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1018
    if overrode_max_train_steps:
1019
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1020
    # Afterwards we recalculate our number of training epochs
1021
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1022

1023
    # We need to initialize the trackers we use, and also store our configuration.
1024
    # The trackers initialize automatically on the main process.
1025
    if accelerator.is_main_process:
1026
        accelerator.init_trackers("dreambooth", config=vars(args))
1027

1028
    # Train!
1029
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1030

1031
    logger.info("***** Running training *****")
1032
    logger.info(f"  Num examples = {len(train_dataset)}")
1033
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
1034
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
1035
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
1036
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1037
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1038
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
1039
    global_step = 0
1040
    first_epoch = 0
1041

1042
    # Potentially load in the weights and states from a previous save
1043
    if args.resume_from_checkpoint:
1044
        if args.resume_from_checkpoint != "latest":
1045
            path = os.path.basename(args.resume_from_checkpoint)
1046
        else:
1047
            # Get the most recent checkpoint
1048
            dirs = os.listdir(args.output_dir)
1049
            dirs = [d for d in dirs if d.startswith("checkpoint")]
1050
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1051
            path = dirs[-1]
1052
        accelerator.print(f"Resuming from checkpoint {path}")
1053
        accelerator.load_state(os.path.join(args.output_dir, path))
1054
        global_step = int(path.split("-")[1])
1055

1056
        resume_global_step = global_step * args.gradient_accumulation_steps
1057
        first_epoch = resume_global_step // num_update_steps_per_epoch
1058
        resume_step = resume_global_step % num_update_steps_per_epoch
1059

1060
    # Only show the progress bar once on each machine.
1061
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
1062
    progress_bar.set_description("Steps")
1063

1064
    for epoch in range(first_epoch, args.num_train_epochs):
1065
        unet.train()
1066
        if args.train_text_encoder:
1067
            text_encoder.train()
1068
        with TorchTracemalloc() as tracemalloc:
1069
            for step, batch in enumerate(train_dataloader):
1070
                # Skip steps until we reach the resumed step
1071
                if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1072
                    if step % args.gradient_accumulation_steps == 0:
1073
                        progress_bar.update(1)
1074
                        if args.report_to == "wandb":
1075
                            accelerator.print(progress_bar)
1076
                    continue
1077

1078
                with accelerator.accumulate(unet):
1079
                    # Convert images to latent space
1080
                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1081
                    latents = latents * 0.18215
1082

1083
                    # Sample noise that we'll add to the latents
1084
                    noise = torch.randn_like(latents)
1085
                    bsz = latents.shape[0]
1086
                    # Sample a random timestep for each image
1087
                    timesteps = torch.randint(
1088
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
1089
                    )
1090
                    timesteps = timesteps.long()
1091

1092
                    # Add noise to the latents according to the noise magnitude at each timestep
1093
                    # (this is the forward diffusion process)
1094
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1095

1096
                    # Get the text embedding for conditioning
1097
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1098

1099
                    # Predict the noise residual
1100
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1101

1102
                    # Get the target for loss depending on the prediction type
1103
                    if noise_scheduler.config.prediction_type == "epsilon":
1104
                        target = noise
1105
                    elif noise_scheduler.config.prediction_type == "v_prediction":
1106
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
1107
                    else:
1108
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1109

1110
                    if args.with_prior_preservation:
1111
                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1112
                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1113
                        target, target_prior = torch.chunk(target, 2, dim=0)
1114

1115
                        # Compute instance loss
1116
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1117

1118
                        # Compute prior loss
1119
                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1120

1121
                        # Add the prior loss to the instance loss.
1122
                        loss = loss + args.prior_loss_weight * prior_loss
1123
                    else:
1124
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1125

1126
                    accelerator.backward(loss)
1127
                    if accelerator.sync_gradients:
1128
                        params_to_clip = (
1129
                            itertools.chain(unet.parameters(), text_encoder.parameters())
1130
                            if args.train_text_encoder
1131
                            else unet.parameters()
1132
                        )
1133
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1134
                    optimizer.step()
1135
                    lr_scheduler.step()
1136
                    optimizer.zero_grad()
1137

1138
                # Checks if the accelerator has performed an optimization step behind the scenes
1139
                if accelerator.sync_gradients:
1140
                    progress_bar.update(1)
1141
                    if args.report_to == "wandb":
1142
                        accelerator.print(progress_bar)
1143
                    global_step += 1
1144

1145
                    # if global_step % args.checkpointing_steps == 0:
1146
                    #     if accelerator.is_main_process:
1147
                    #         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1148
                    #         accelerator.save_state(save_path)
1149
                    #         logger.info(f"Saved state to {save_path}")
1150

1151
                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1152
                progress_bar.set_postfix(**logs)
1153
                accelerator.log(logs, step=global_step)
1154

1155
                if (
1156
                    args.validation_prompt is not None
1157
                    and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0
1158
                ):
1159
                    logger.info(
1160
                        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1161
                        f" {args.validation_prompt}."
1162
                    )
1163
                    # create pipeline
1164
                    pipeline = DiffusionPipeline.from_pretrained(
1165
                        args.pretrained_model_name_or_path,
1166
                        safety_checker=None,
1167
                        revision=args.revision,
1168
                    )
1169
                    # set `keep_fp32_wrapper` to True because we do not want to remove
1170
                    # mixed precision hooks while we are still training
1171
                    pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
1172
                    pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
1173
                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1174
                    pipeline = pipeline.to(accelerator.device)
1175
                    pipeline.set_progress_bar_config(disable=True)
1176

1177
                    # Set evaluation mode
1178
                    pipeline.unet.eval()
1179
                    pipeline.text_encoder.eval()
1180

1181
                    # run inference
1182
                    if args.seed is not None:
1183
                        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1184
                    else:
1185
                        generator = None
1186
                    images = []
1187
                    for _ in range(args.num_validation_images):
1188
                        image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1189
                        images.append(image)
1190

1191
                    for tracker in accelerator.trackers:
1192
                        if tracker.name == "tensorboard":
1193
                            np_images = np.stack([np.asarray(img) for img in images])
1194
                            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1195
                        if tracker.name == "wandb":
1196
                            import wandb
1197

1198
                            tracker.log(
1199
                                {
1200
                                    "validation": [
1201
                                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1202
                                        for i, image in enumerate(images)
1203
                                    ]
1204
                                }
1205
                            )
1206

1207
                    # Set back to training mode
1208
                    pipeline.unet.train()
1209
                    pipeline.text_encoder.train()
1210

1211
                    del pipeline
1212
                    torch.cuda.empty_cache()
1213

1214
                if global_step >= args.max_train_steps:
1215
                    break
1216
        # Print GPU/CPU memory usage details: allocated memory, peak memory, and total memory usage (in MiB)
        accelerator.print(f"GPU Memory before entering training: {b2mb(tracemalloc.begin)}")
        accelerator.print(f"GPU Memory consumed at the end of training (end-begin): {tracemalloc.used}")
        accelerator.print(f"GPU Peak Memory consumed during training (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"GPU Total Peak Memory consumed during training (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )

        accelerator.print(f"CPU Memory before entering training: {b2mb(tracemalloc.cpu_begin)}")
        accelerator.print(f"CPU Memory consumed at the end of training (end-begin): {tracemalloc.cpu_used}")
        accelerator.print(f"CPU Peak Memory consumed during training (max-begin): {tracemalloc.cpu_peaked}")
        accelerator.print(
            "CPU Total Peak Memory consumed during training (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
            )
        )
1232

1233
    # Create the pipeline using the trained modules and save it.
1234
    accelerator.wait_for_everyone()
1235
    if accelerator.is_main_process:
1236
        if args.adapter != "full":
1237
            unwrapped_unet = accelerator.unwrap_model(unet)
            unwrapped_unet.save_pretrained(
                os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet)
            )
            if args.train_text_encoder:
                unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
                unwrapped_text_encoder.save_pretrained(
                    os.path.join(args.output_dir, "text_encoder"),
                    state_dict=accelerator.get_state_dict(text_encoder),
1246
                )
1247
        else:
1248
            pipeline = DiffusionPipeline.from_pretrained(
1249
                args.pretrained_model_name_or_path,
1250
                unet=accelerator.unwrap_model(unet),
1251
                text_encoder=accelerator.unwrap_model(text_encoder),
1252
                revision=args.revision,
1253
            )
1254
            pipeline.save_pretrained(args.output_dir)
1255

1256
        if args.push_to_hub:
1257
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1258

1259
    accelerator.end_training()
1260

1261

1262
if __name__ == "__main__":
1263
    args = parse_args()
1264
    main(args)