optimum-habana

Форк
0
/
test_diffusers.py 
1950 строк · 68.7 Кб
1
# coding=utf-8
2
# Copyright 2022 The HuggingFace Inc. team.
3
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16

17
import json
18
import os
19
import re
20
import subprocess
21
import tempfile
22
from io import BytesIO
23
from pathlib import Path
24
from unittest import TestCase, skipUnless
25

26
import numpy as np
27
import requests
28
import torch
29
from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UniPCMultistepScheduler
30
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
31
from diffusers.utils import load_image
32
from diffusers.utils.torch_utils import randn_tensor
33
from huggingface_hub import snapshot_download
34
from parameterized import parameterized
35
from PIL import Image
36
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
37
from transformers.testing_utils import parse_flag_from_env, slow
38

39
from optimum.habana import GaudiConfig
40
from optimum.habana.diffusers import (
41
    GaudiDDIMScheduler,
42
    GaudiDiffusionPipeline,
43
    GaudiEulerAncestralDiscreteScheduler,
44
    GaudiEulerDiscreteScheduler,
45
    GaudiStableDiffusionControlNetPipeline,
46
    GaudiStableDiffusionLDM3DPipeline,
47
    GaudiStableDiffusionPipeline,
48
    GaudiStableDiffusionUpscalePipeline,
49
    GaudiStableDiffusionXLPipeline,
50
)
51
from optimum.habana.utils import set_seed
52

53
from .clip_coco_utils import download_files
54

55

56
if os.environ.get("GAUDI2_CI", "0") == "1":
57
    THROUGHPUT_BASELINE_BF16 = 1.016
58
    THROUGHPUT_BASELINE_AUTOCAST = 0.394
59
    TEXTUAL_INVERSION_THROUGHPUT = 104.29806
60
    TEXTUAL_INVERSION_RUNTIME = 114.1344320399221
61
    CONTROLNET_THROUGHPUT = 92.886919836857
62
    CONTROLNET_RUNTIME = 537.4276602957398
63
else:
64
    THROUGHPUT_BASELINE_BF16 = 0.309
65
    THROUGHPUT_BASELINE_AUTOCAST = 0.114
66
    TEXTUAL_INVERSION_THROUGHPUT = 58.17508958300077
67
    TEXTUAL_INVERSION_RUNTIME = 202.94231038199996
68
    CONTROLNET_THROUGHPUT = 44.412012818816905
69
    CONTROLNET_RUNTIME = 1124.0202105600001
70

71

72
_run_custom_bf16_ops_test_ = parse_flag_from_env("CUSTOM_BF16_OPS", default=False)
73

74

75
def custom_bf16_ops(test_case):
76
    """
77
    Decorator marking a test as needing custom bf16 ops.
78
    Custom bf16 ops must be declared before `habana_frameworks.torch.core` is imported, which is not possible if some other tests are executed before.
79

80
    Such tests are skipped by default. Set the CUSTOM_BF16_OPS environment variable to a truthy value to run them.
81

82
    """
83
    return skipUnless(_run_custom_bf16_ops_test_, "test requires custom bf16 ops")(test_case)
84

85

86
class GaudiPipelineUtilsTester(TestCase):
87
    """
88
    Tests the features added on top of diffusers/pipeline_utils.py.
89
    """
90

91
    def test_use_hpu_graphs_raise_error_without_habana(self):
92
        with self.assertRaises(ValueError):
93
            _ = GaudiDiffusionPipeline(
94
                use_habana=False,
95
                use_hpu_graphs=True,
96
            )
97

98
    def test_gaudi_config_raise_error_without_habana(self):
99
        with self.assertRaises(ValueError):
100
            _ = GaudiDiffusionPipeline(
101
                use_habana=False,
102
                gaudi_config=GaudiConfig(),
103
            )
104

105
    def test_device(self):
106
        pipeline_1 = GaudiDiffusionPipeline(
107
            use_habana=True,
108
            gaudi_config=GaudiConfig(),
109
        )
110
        self.assertEqual(pipeline_1._device.type, "hpu")
111

112
        pipeline_2 = GaudiDiffusionPipeline(
113
            use_habana=False,
114
        )
115
        self.assertEqual(pipeline_2._device.type, "cpu")
116

117
    def test_gaudi_config_types(self):
118
        # gaudi_config is a string
119
        _ = GaudiDiffusionPipeline(
120
            use_habana=True,
121
            gaudi_config="Habana/stable-diffusion",
122
        )
123

124
        # gaudi_config is instantiated beforehand
125
        gaudi_config = GaudiConfig.from_pretrained("Habana/stable-diffusion")
126
        _ = GaudiDiffusionPipeline(
127
            use_habana=True,
128
            gaudi_config=gaudi_config,
129
        )
130

131
    def test_default(self):
132
        pipeline = GaudiDiffusionPipeline(
133
            use_habana=True,
134
            gaudi_config=GaudiConfig(),
135
        )
136

137
        self.assertTrue(hasattr(pipeline, "htcore"))
138

139
    def test_use_hpu_graphs(self):
140
        pipeline = GaudiDiffusionPipeline(
141
            use_habana=True,
142
            use_hpu_graphs=True,
143
            gaudi_config=GaudiConfig(),
144
        )
145

146
        self.assertTrue(hasattr(pipeline, "ht"))
147
        self.assertTrue(hasattr(pipeline, "hpu_stream"))
148
        self.assertTrue(hasattr(pipeline, "cache"))
149

150
    def test_save_pretrained(self):
151
        model_name = "hf-internal-testing/tiny-stable-diffusion-torch"
152
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
153
        pipeline = GaudiStableDiffusionPipeline.from_pretrained(
154
            model_name,
155
            scheduler=scheduler,
156
            use_habana=True,
157
            gaudi_config=GaudiConfig(),
158
        )
159

160
        with tempfile.TemporaryDirectory() as tmp_dir:
161
            pipeline.save_pretrained(tmp_dir)
162
            self.assertTrue(Path(tmp_dir, "gaudi_config.json").is_file())
163

164

165
class GaudiStableDiffusionPipelineTester(TestCase):
166
    """
167
    Tests the StableDiffusionPipeline for Gaudi.
168
    """
169

170
    def get_dummy_components(self, time_cond_proj_dim=None):
171
        torch.manual_seed(0)
172
        unet = UNet2DConditionModel(
173
            block_out_channels=(4, 8),
174
            layers_per_block=1,
175
            sample_size=32,
176
            time_cond_proj_dim=time_cond_proj_dim,
177
            in_channels=4,
178
            out_channels=4,
179
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
180
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
181
            cross_attention_dim=32,
182
            norm_num_groups=2,
183
        )
184
        scheduler = GaudiDDIMScheduler(
185
            beta_start=0.00085,
186
            beta_end=0.012,
187
            beta_schedule="scaled_linear",
188
            clip_sample=False,
189
            set_alpha_to_one=False,
190
        )
191
        torch.manual_seed(0)
192
        vae = AutoencoderKL(
193
            block_out_channels=[4, 8],
194
            in_channels=3,
195
            out_channels=3,
196
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
197
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
198
            latent_channels=4,
199
            norm_num_groups=2,
200
        )
201
        torch.manual_seed(0)
202
        text_encoder_config = CLIPTextConfig(
203
            bos_token_id=0,
204
            eos_token_id=2,
205
            hidden_size=32,
206
            intermediate_size=64,
207
            layer_norm_eps=1e-05,
208
            num_attention_heads=8,
209
            num_hidden_layers=3,
210
            pad_token_id=1,
211
            vocab_size=1000,
212
        )
213
        text_encoder = CLIPTextModel(text_encoder_config)
214
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
215

216
        components = {
217
            "unet": unet,
218
            "scheduler": scheduler,
219
            "vae": vae,
220
            "text_encoder": text_encoder,
221
            "tokenizer": tokenizer,
222
            "safety_checker": None,
223
            "feature_extractor": None,
224
        }
225
        return components
226

227
    def get_dummy_inputs(self, device, seed=0):
228
        generator = torch.Generator(device=device).manual_seed(seed)
229
        inputs = {
230
            "prompt": "A painting of a squirrel eating a burger",
231
            "generator": generator,
232
            "num_inference_steps": 2,
233
            "guidance_scale": 6.0,
234
            "output_type": "numpy",
235
        }
236
        return inputs
237

238
    def test_stable_diffusion_ddim(self):
239
        device = "cpu"
240

241
        components = self.get_dummy_components()
242
        gaudi_config = GaudiConfig(use_torch_autocast=False)
243

244
        sd_pipe = GaudiStableDiffusionPipeline(
245
            use_habana=True,
246
            gaudi_config=gaudi_config,
247
            **components,
248
        )
249
        sd_pipe.set_progress_bar_config(disable=None)
250

251
        inputs = self.get_dummy_inputs(device)
252
        output = sd_pipe(**inputs)
253
        image = output.images[0]
254

255
        image_slice = image[-3:, -3:, -1]
256

257
        self.assertEqual(image.shape, (64, 64, 3))
258
        expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
259

260
        self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
261

262
    def test_stable_diffusion_no_safety_checker(self):
263
        gaudi_config = GaudiConfig()
264
        scheduler = GaudiDDIMScheduler(
265
            beta_start=0.00085,
266
            beta_end=0.012,
267
            beta_schedule="scaled_linear",
268
            clip_sample=False,
269
            set_alpha_to_one=False,
270
        )
271
        pipe = GaudiStableDiffusionPipeline.from_pretrained(
272
            "hf-internal-testing/tiny-stable-diffusion-pipe",
273
            scheduler=scheduler,
274
            safety_checker=None,
275
            use_habana=True,
276
            gaudi_config=gaudi_config,
277
        )
278
        self.assertIsInstance(pipe, GaudiStableDiffusionPipeline)
279
        self.assertIsInstance(pipe.scheduler, GaudiDDIMScheduler)
280
        self.assertIsNone(pipe.safety_checker)
281

282
        image = pipe("example prompt", num_inference_steps=2).images[0]
283
        self.assertIsNotNone(image)
284

285
        # Check that there's no error when saving a pipeline with one of the models being None
286
        with tempfile.TemporaryDirectory() as tmpdirname:
287
            pipe.save_pretrained(tmpdirname)
288
            pipe = GaudiStableDiffusionPipeline.from_pretrained(
289
                tmpdirname,
290
                use_habana=True,
291
                gaudi_config=tmpdirname,
292
            )
293

294
        # Sanity check that the pipeline still works
295
        self.assertIsNone(pipe.safety_checker)
296
        image = pipe("example prompt", num_inference_steps=2).images[0]
297
        self.assertIsNotNone(image)
298

299
    @parameterized.expand(["pil", "np", "latent"])
300
    def test_stable_diffusion_output_types(self, output_type):
301
        components = self.get_dummy_components()
302
        gaudi_config = GaudiConfig()
303

304
        sd_pipe = GaudiStableDiffusionPipeline(
305
            use_habana=True,
306
            gaudi_config=gaudi_config,
307
            **components,
308
        )
309
        sd_pipe.set_progress_bar_config(disable=None)
310

311
        prompt = "A painting of a squirrel eating a burger"
312
        num_prompts = 2
313
        num_images_per_prompt = 3
314

315
        outputs = sd_pipe(
316
            num_prompts * [prompt],
317
            num_images_per_prompt=num_images_per_prompt,
318
            num_inference_steps=2,
319
            output_type=output_type,
320
        )
321

322
        self.assertEqual(len(outputs.images), 2 * 3)
323
        # TODO: enable safety checker
324
        # if output_type == "latent":
325
        #     self.assertIsNone(outputs.nsfw_content_detected)
326
        # else:
327
        #     self.assertEqual(len(outputs.nsfw_content_detected), 2 * 3)
328

329
    # TODO: enable this test when PNDMScheduler is adapted to Gaudi
330
    # def test_stable_diffusion_negative_prompt(self):
331
    #     device = "cpu"  # ensure determinism for the device-dependent torch.Generator
332
    #     unet = self.dummy_cond_unet
333
    #     scheduler = PNDMScheduler(skip_prk_steps=True)
334
    #     vae = self.dummy_vae
335
    #     bert = self.dummy_text_encoder
336
    #     tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
337

338
    #     # make sure here that pndm scheduler skips prk
339
    #     sd_pipe = StableDiffusionPipeline(
340
    #         unet=unet,
341
    #         scheduler=scheduler,
342
    #         vae=vae,
343
    #         text_encoder=bert,
344
    #         tokenizer=tokenizer,
345
    #         safety_checker=None,
346
    #         feature_extractor=self.dummy_extractor,
347
    #     )
348
    #     sd_pipe = sd_pipe.to(device)
349
    #     sd_pipe.set_progress_bar_config(disable=None)
350

351
    #     prompt = "A painting of a squirrel eating a burger"
352
    #     negative_prompt = "french fries"
353
    #     generator = torch.Generator(device=device).manual_seed(0)
354
    #     output = sd_pipe(
355
    #         prompt,
356
    #         negative_prompt=negative_prompt,
357
    #         generator=generator,
358
    #         guidance_scale=6.0,
359
    #         num_inference_steps=2,
360
    #         output_type="np",
361
    #     )
362

363
    #     image = output.images
364
    #     image_slice = image[0, -3:, -3:, -1]
365

366
    #     assert image.shape == (1, 128, 128, 3)
367
    #     expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
368
    #     assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
369

370
    def test_stable_diffusion_num_images_per_prompt(self):
371
        components = self.get_dummy_components()
372
        gaudi_config = GaudiConfig()
373

374
        sd_pipe = GaudiStableDiffusionPipeline(
375
            use_habana=True,
376
            gaudi_config=gaudi_config,
377
            **components,
378
        )
379
        sd_pipe.set_progress_bar_config(disable=None)
380

381
        prompt = "A painting of a squirrel eating a burger"
382

383
        # Test num_images_per_prompt=1 (default)
384
        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
385

386
        self.assertEqual(len(images), 1)
387
        self.assertEqual(images[0].shape, (64, 64, 3))
388

389
        # Test num_images_per_prompt=1 (default) for several prompts
390
        num_prompts = 3
391
        images = sd_pipe([prompt] * num_prompts, num_inference_steps=2, output_type="np").images
392

393
        self.assertEqual(len(images), num_prompts)
394
        self.assertEqual(images[-1].shape, (64, 64, 3))
395

396
        # Test num_images_per_prompt for single prompt
397
        num_images_per_prompt = 2
398
        images = sd_pipe(
399
            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
400
        ).images
401

402
        self.assertEqual(len(images), num_images_per_prompt)
403
        self.assertEqual(images[-1].shape, (64, 64, 3))
404

405
        # Test num_images_per_prompt for several prompts
406
        num_prompts = 2
407
        images = sd_pipe(
408
            [prompt] * num_prompts,
409
            num_inference_steps=2,
410
            output_type="np",
411
            num_images_per_prompt=num_images_per_prompt,
412
        ).images
413

414
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
415
        self.assertEqual(images[-1].shape, (64, 64, 3))
416

417
    def test_stable_diffusion_batch_sizes(self):
418
        components = self.get_dummy_components()
419
        gaudi_config = GaudiConfig()
420

421
        sd_pipe = GaudiStableDiffusionPipeline(
422
            use_habana=True,
423
            gaudi_config=gaudi_config,
424
            **components,
425
        )
426
        sd_pipe.set_progress_bar_config(disable=None)
427

428
        prompt = "A painting of a squirrel eating a burger"
429

430
        # Test batch_size > 1 where batch_size is a divider of the total number of generated images
431
        batch_size = 3
432
        num_images_per_prompt = batch_size**2
433
        images = sd_pipe(
434
            prompt,
435
            num_inference_steps=2,
436
            output_type="np",
437
            batch_size=batch_size,
438
            num_images_per_prompt=num_images_per_prompt,
439
        ).images
440

441
        self.assertEqual(len(images), num_images_per_prompt)
442
        self.assertEqual(images[-1].shape, (64, 64, 3))
443

444
        # Same test for several prompts
445
        num_prompts = 3
446
        images = sd_pipe(
447
            [prompt] * num_prompts,
448
            num_inference_steps=2,
449
            output_type="np",
450
            batch_size=batch_size,
451
            num_images_per_prompt=num_images_per_prompt,
452
        ).images
453

454
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
455
        self.assertEqual(images[-1].shape, (64, 64, 3))
456

457
        # Test batch_size when it is not a divider of the toal number of generated images for a single prompt
458
        num_images_per_prompt = 7
459
        images = sd_pipe(
460
            prompt,
461
            num_inference_steps=2,
462
            output_type="np",
463
            batch_size=batch_size,
464
            num_images_per_prompt=num_images_per_prompt,
465
        ).images
466

467
        self.assertEqual(len(images), num_images_per_prompt)
468
        self.assertEqual(images[-1].shape, (64, 64, 3))
469

470
        # Same test for several prompts
471
        num_prompts = 2
472
        images = sd_pipe(
473
            [prompt] * num_prompts,
474
            num_inference_steps=2,
475
            output_type="np",
476
            batch_size=batch_size,
477
            num_images_per_prompt=num_images_per_prompt,
478
        ).images
479

480
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
481
        self.assertEqual(images[-1].shape, (64, 64, 3))
482

483
    def test_stable_diffusion_bf16(self):
484
        """Test that stable diffusion works with bf16"""
485
        components = self.get_dummy_components()
486
        gaudi_config = GaudiConfig()
487

488
        sd_pipe = GaudiStableDiffusionPipeline(
489
            use_habana=True,
490
            gaudi_config=gaudi_config,
491
            **components,
492
        )
493
        sd_pipe.set_progress_bar_config(disable=None)
494

495
        prompt = "A painting of a squirrel eating a burger"
496
        generator = torch.Generator(device="cpu").manual_seed(0)
497
        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images[0]
498

499
        self.assertEqual(image.shape, (64, 64, 3))
500

501
    def test_stable_diffusion_default(self):
502
        components = self.get_dummy_components()
503

504
        sd_pipe = GaudiStableDiffusionPipeline(
505
            use_habana=True,
506
            gaudi_config="Habana/stable-diffusion",
507
            **components,
508
        )
509
        sd_pipe.set_progress_bar_config(disable=None)
510

511
        prompt = "A painting of a squirrel eating a burger"
512
        generator = torch.Generator(device="cpu").manual_seed(0)
513
        images = sd_pipe(
514
            [prompt] * 2,
515
            generator=generator,
516
            num_inference_steps=2,
517
            output_type="np",
518
            batch_size=3,
519
            num_images_per_prompt=5,
520
        ).images
521

522
        self.assertEqual(len(images), 10)
523
        self.assertEqual(images[-1].shape, (64, 64, 3))
524

525
    def test_stable_diffusion_hpu_graphs(self):
526
        components = self.get_dummy_components()
527

528
        sd_pipe = GaudiStableDiffusionPipeline(
529
            use_habana=True,
530
            use_hpu_graphs=True,
531
            gaudi_config="Habana/stable-diffusion",
532
            **components,
533
        )
534
        sd_pipe.set_progress_bar_config(disable=None)
535

536
        prompt = "A painting of a squirrel eating a burger"
537
        generator = torch.Generator(device="cpu").manual_seed(0)
538
        images = sd_pipe(
539
            [prompt] * 2,
540
            generator=generator,
541
            num_inference_steps=2,
542
            output_type="np",
543
            batch_size=3,
544
            num_images_per_prompt=5,
545
        ).images
546

547
        self.assertEqual(len(images), 10)
548
        self.assertEqual(images[-1].shape, (64, 64, 3))
549

550
    @slow
551
    def test_no_throughput_regression_bf16(self):
552
        prompts = [
553
            "An image of a squirrel in Picasso style",
554
            "High quality photo of an astronaut riding a horse in space",
555
        ]
556
        num_images_per_prompt = 11
557
        batch_size = 4
558
        model_name = "runwayml/stable-diffusion-v1-5"
559
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
560

561
        pipeline = GaudiStableDiffusionPipeline.from_pretrained(
562
            model_name,
563
            scheduler=scheduler,
564
            use_habana=True,
565
            use_hpu_graphs=True,
566
            gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"),
567
            torch_dtype=torch.bfloat16,
568
        )
569
        set_seed(27)
570
        outputs = pipeline(
571
            prompt=prompts,
572
            num_images_per_prompt=num_images_per_prompt,
573
            batch_size=batch_size,
574
        )
575
        self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
576
        self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_BASELINE_BF16)
577

578
    @custom_bf16_ops
579
    @slow
580
    def test_no_throughput_regression_autocast(self):
581
        prompts = [
582
            "An image of a squirrel in Picasso style",
583
            "High quality photo of an astronaut riding a horse in space",
584
        ]
585
        num_images_per_prompt = 11
586
        batch_size = 4
587
        model_name = "stabilityai/stable-diffusion-2-1"
588
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
589

590
        pipeline = GaudiStableDiffusionPipeline.from_pretrained(
591
            model_name,
592
            scheduler=scheduler,
593
            use_habana=True,
594
            use_hpu_graphs=True,
595
            gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion-2"),
596
        )
597
        set_seed(27)
598
        outputs = pipeline(
599
            prompt=prompts,
600
            num_images_per_prompt=num_images_per_prompt,
601
            batch_size=batch_size,
602
            height=768,
603
            width=768,
604
        )
605
        self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
606
        self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_BASELINE_AUTOCAST)
607

608
    @slow
609
    def test_no_generation_regression(self):
610
        model_name = "CompVis/stable-diffusion-v1-4"
611
        # fp32
612
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
613
        pipeline = GaudiStableDiffusionPipeline.from_pretrained(
614
            model_name,
615
            scheduler=scheduler,
616
            safety_checker=None,
617
            use_habana=True,
618
            use_hpu_graphs=True,
619
            gaudi_config=GaudiConfig(use_torch_autocast=False),
620
        )
621
        set_seed(27)
622
        outputs = pipeline(
623
            prompt="An image of a squirrel in Picasso style",
624
            output_type="np",
625
        )
626

627
        if os.environ.get("GAUDI2_CI", "0") == "1":
628
            expected_slice = np.array(
629
                [
630
                    0.68306947,
631
                    0.6812112,
632
                    0.67309505,
633
                    0.70057267,
634
                    0.6582885,
635
                    0.6325019,
636
                    0.6708976,
637
                    0.6226433,
638
                    0.58038336,
639
                ]
640
            )
641
        else:
642
            expected_slice = np.array(
643
                [0.70760196, 0.7136303, 0.7000798, 0.714934, 0.6776865, 0.6800843, 0.6923707, 0.6653969, 0.6408076]
644
            )
645
        image = outputs.images[0]
646

647
        self.assertEqual(image.shape, (512, 512, 3))
648
        self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3)
649

650
    @slow
651
    def test_no_generation_regression_ldm3d(self):
652
        model_name = "Intel/ldm3d-4c"
653
        # fp32
654
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
655
        pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
656
            model_name,
657
            scheduler=scheduler,
658
            safety_checker=None,
659
            use_habana=True,
660
            use_hpu_graphs=True,
661
            gaudi_config=GaudiConfig(),
662
        )
663
        set_seed(27)
664
        outputs = pipeline(
665
            prompt="An image of a squirrel in Picasso style",
666
            output_type="np",
667
        )
668

669
        if os.environ.get("GAUDI2_CI", "0") == "1":
670
            expected_slice_rgb = np.array(
671
                [
672
                    0.2099357,
673
                    0.16664368,
674
                    0.08352646,
675
                    0.20643419,
676
                    0.16748399,
677
                    0.08781305,
678
                    0.21379063,
679
                    0.19943115,
680
                    0.04389626,
681
                ]
682
            )
683
            expected_slice_depth = np.array(
684
                [
685
                    0.68369114,
686
                    0.6827824,
687
                    0.6852779,
688
                    0.6836072,
689
                    0.6888298,
690
                    0.6895473,
691
                    0.6853674,
692
                    0.67561126,
693
                    0.660434,
694
                ]
695
            )
696
        else:
697
            expected_slice_rgb = np.array([0.7083766, 1.0, 1.0, 0.70610344, 0.9867363, 1.0, 0.7214538, 1.0, 1.0])
698
            expected_slice_depth = np.array(
699
                [
700
                    0.919621,
701
                    0.92072034,
702
                    0.9184986,
703
                    0.91994286,
704
                    0.9242079,
705
                    0.93387043,
706
                    0.92345214,
707
                    0.93558526,
708
                    0.9223714,
709
                ]
710
            )
711
        rgb = outputs.rgb[0]
712
        depth = outputs.depth[0]
713

714
        self.assertEqual(rgb.shape, (512, 512, 3))
715
        self.assertEqual(depth.shape, (512, 512, 1))
716
        self.assertLess(np.abs(expected_slice_rgb - rgb[-3:, -3:, -1].flatten()).max(), 5e-3)
717
        self.assertLess(np.abs(expected_slice_depth - depth[-3:, -3:, -1].flatten()).max(), 5e-3)
718

719
    @slow
720
    def test_no_generation_regression_upscale(self):
721
        model_name = "stabilityai/stable-diffusion-x4-upscaler"
722
        # fp32
723
        scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
724
        pipeline = GaudiStableDiffusionUpscalePipeline.from_pretrained(
725
            model_name,
726
            scheduler=scheduler,
727
            use_habana=True,
728
            use_hpu_graphs=True,
729
            gaudi_config=GaudiConfig(use_torch_autocast=False),
730
        )
731
        set_seed(27)
732

733
        url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
734
        response = requests.get(url)
735
        low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
736
        low_res_img = low_res_img.resize((128, 128))
737
        prompt = "a white cat"
738
        upscaled_image = pipeline(prompt=prompt, image=low_res_img, output_type="np").images[0]
739
        if os.environ.get("GAUDI2_CI", "0") == "1":
740
            expected_slice = np.array(
741
                [
742
                    0.16527882,
743
                    0.161616,
744
                    0.15665859,
745
                    0.1660901,
746
                    0.1594379,
747
                    0.14936888,
748
                    0.1578255,
749
                    0.15342498,
750
                    0.14590919,
751
                ]
752
            )
753
        else:
754
            expected_slice = np.array(
755
                [
756
                    0.1652787,
757
                    0.16161594,
758
                    0.15665877,
759
                    0.16608998,
760
                    0.1594378,
761
                    0.14936894,
762
                    0.15782538,
763
                    0.15342498,
764
                    0.14590913,
765
                ]
766
            )
767
        self.assertEqual(upscaled_image.shape, (512, 512, 3))
768
        self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3)
769

770
    @slow
771
    def test_textual_inversion(self):
772
        path_to_script = (
773
            Path(os.path.dirname(__file__)).parent
774
            / "examples"
775
            / "stable-diffusion"
776
            / "training"
777
            / "textual_inversion.py"
778
        )
779

780
        with tempfile.TemporaryDirectory() as data_dir:
781
            snapshot_download(
782
                "diffusers/cat_toy_example", local_dir=data_dir, repo_type="dataset", ignore_patterns=".gitattributes"
783
            )
784
            with tempfile.TemporaryDirectory() as run_dir:
785
                cmd_line = [
786
                    "python3",
787
                    f"{path_to_script.parent.parent.parent / 'gaudi_spawn.py'}",
788
                    "--use_mpi",
789
                    "--world_size",
790
                    "8",
791
                    f"{path_to_script}",
792
                    "--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5",
793
                    f"--train_data_dir {data_dir}",
794
                    '--learnable_property "object"',
795
                    '--placeholder_token "<cat-toy>"',
796
                    '--initializer_token "toy"',
797
                    "--resolution 512",
798
                    "--train_batch_size 4",
799
                    "--max_train_steps 375",
800
                    "--learning_rate 5.0e-04",
801
                    "--scale_lr",
802
                    '--lr_scheduler "constant"',
803
                    "--lr_warmup_steps 0",
804
                    f"--output_dir {run_dir}",
805
                    "--save_as_full_pipeline",
806
                    "--gaudi_config_name Habana/stable-diffusion",
807
                    "--throughput_warmup_steps 3",
808
                    "--seed 27",
809
                ]
810
                pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
811
                cmd_line = [x for y in cmd_line for x in re.split(pattern, y) if x]
812

813
                # Run textual inversion
814
                p = subprocess.Popen(cmd_line)
815
                return_code = p.wait()
816

817
                # Ensure the run finished without any issue
818
                self.assertEqual(return_code, 0)
819

820
                # Assess throughput
821
                with open(Path(run_dir) / "speed_metrics.json") as fp:
822
                    results = json.load(fp)
823
                self.assertGreaterEqual(results["train_samples_per_second"], 0.95 * TEXTUAL_INVERSION_THROUGHPUT)
824
                self.assertLessEqual(results["train_runtime"], 1.05 * TEXTUAL_INVERSION_RUNTIME)
825

826
                # Assess generated image
827
                pipe = GaudiStableDiffusionPipeline.from_pretrained(
828
                    run_dir,
829
                    torch_dtype=torch.bfloat16,
830
                    use_habana=True,
831
                    use_hpu_graphs=True,
832
                    gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
833
                )
834
                prompt = "A <cat-toy> backpack"
835
                set_seed(27)
836
                image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, output_type="np").images[0]
837

838
                # TODO: see how to generate images in a reproducible way
839
                # expected_slice = np.array(
840
                #     [0.57421875, 0.5703125, 0.58203125, 0.58203125, 0.578125, 0.5859375, 0.578125, 0.57421875, 0.56640625]
841
                # )
842
                self.assertEqual(image.shape, (512, 512, 3))
843
                # self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3)
844

845

846
class GaudiStableDiffusionXLPipelineTester(TestCase):
847
    """
848
    Tests the StableDiffusionXLPipeline for Gaudi.
849
    """
850

851
    def get_dummy_components(self, time_cond_proj_dim=None, timestep_spacing="leading"):
852
        torch.manual_seed(0)
853
        unet = UNet2DConditionModel(
854
            block_out_channels=(2, 4),
855
            layers_per_block=2,
856
            time_cond_proj_dim=time_cond_proj_dim,
857
            sample_size=32,
858
            in_channels=4,
859
            out_channels=4,
860
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
861
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
862
            # SD2-specific config below
863
            attention_head_dim=(2, 4),
864
            use_linear_projection=True,
865
            addition_embed_type="text_time",
866
            addition_time_embed_dim=8,
867
            transformer_layers_per_block=(1, 2),
868
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
869
            cross_attention_dim=64,
870
            norm_num_groups=1,
871
        )
872
        scheduler = GaudiEulerDiscreteScheduler(
873
            beta_start=0.00085,
874
            beta_end=0.012,
875
            steps_offset=1,
876
            beta_schedule="scaled_linear",
877
            timestep_spacing=timestep_spacing,
878
        )
879
        torch.manual_seed(0)
880
        vae = AutoencoderKL(
881
            block_out_channels=[32, 64],
882
            in_channels=3,
883
            out_channels=3,
884
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
885
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
886
            latent_channels=4,
887
            sample_size=128,
888
        )
889
        torch.manual_seed(0)
890
        text_encoder_config = CLIPTextConfig(
891
            bos_token_id=0,
892
            eos_token_id=2,
893
            hidden_size=32,
894
            intermediate_size=37,
895
            layer_norm_eps=1e-05,
896
            num_attention_heads=4,
897
            num_hidden_layers=5,
898
            pad_token_id=1,
899
            vocab_size=1000,
900
            # SD2-specific config below
901
            hidden_act="gelu",
902
            projection_dim=32,
903
        )
904
        text_encoder = CLIPTextModel(text_encoder_config)
905
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
906

907
        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
908
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
909

910
        components = {
911
            "unet": unet,
912
            "scheduler": scheduler,
913
            "vae": vae,
914
            "text_encoder": text_encoder,
915
            "tokenizer": tokenizer,
916
            "text_encoder_2": text_encoder_2,
917
            "tokenizer_2": tokenizer_2,
918
            "image_encoder": None,
919
            "feature_extractor": None,
920
        }
921
        return components
922

923
    def get_dummy_inputs(self, device, seed=0):
924
        generator = torch.Generator(device=device).manual_seed(seed)
925
        inputs = {
926
            "prompt": "A painting of a squirrel eating a burger",
927
            "generator": generator,
928
            "num_inference_steps": 2,
929
            "guidance_scale": 5.0,
930
            "output_type": "np",
931
        }
932
        return inputs
933

934
    def test_stable_diffusion_xl_euler(self):
935
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
936
        components = self.get_dummy_components()
937
        gaudi_config = GaudiConfig(use_torch_autocast=False)
938
        sd_pipe = GaudiStableDiffusionXLPipeline(use_habana=True, gaudi_config=gaudi_config, **components)
939
        sd_pipe.set_progress_bar_config(disable=None)
940

941
        inputs = self.get_dummy_inputs(device)
942
        image = sd_pipe(**inputs).images[0]
943

944
        image_slice = image[-3:, -3:, -1]
945

946
        self.assertEqual(image.shape, (64, 64, 3))
947
        expected_slice = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.47])
948

949
        # The threshold should be 1e-2 below but it started failing
950
        # from Diffusers v0.24. However, generated images still look similar.
951
        self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
952

953
    def test_stable_diffusion_xl_euler_ancestral(self):
954
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
955
        components = self.get_dummy_components()
956
        gaudi_config = GaudiConfig(use_torch_autocast=False)
957
        sd_pipe = GaudiStableDiffusionXLPipeline(use_habana=True, gaudi_config=gaudi_config, **components)
958
        sd_pipe.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
959
        sd_pipe.set_progress_bar_config(disable=None)
960

961
        inputs = self.get_dummy_inputs(device)
962
        image = sd_pipe(**inputs).images[0]
963

964
        image_slice = image[-3:, -3:, -1]
965

966
        self.assertEqual(image.shape, (64, 64, 3))
967
        expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791])
968
        self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
969

970
    def test_stable_diffusion_xl_turbo_euler_ancestral(self):
971
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
972
        components = self.get_dummy_components(timestep_spacing="trailing")
973
        gaudi_config = GaudiConfig(use_torch_autocast=False)
974

975
        sd_pipe = GaudiStableDiffusionXLPipeline(use_habana=True, gaudi_config=gaudi_config, **components)
976
        sd_pipe.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
977

978
        sd_pipe.set_progress_bar_config(disable=None)
979

980
        inputs = self.get_dummy_inputs(device)
981
        image = sd_pipe(**inputs).images[0]
982

983
        image_slice = image[-3:, -3:, -1]
984

985
        self.assertEqual(image.shape, (64, 64, 3))
986
        expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791])
987
        self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
988

989
    @parameterized.expand(["pil", "np", "latent"])
990
    def test_stable_diffusion_xl_output_types(self, output_type):
991
        components = self.get_dummy_components()
992
        gaudi_config = GaudiConfig()
993

994
        sd_pipe = GaudiStableDiffusionXLPipeline(
995
            use_habana=True,
996
            gaudi_config=gaudi_config,
997
            **components,
998
        )
999
        sd_pipe.set_progress_bar_config(disable=None)
1000

1001
        prompt = "A painting of a squirrel eating a burger"
1002
        num_prompts = 2
1003
        num_images_per_prompt = 3
1004

1005
        outputs = sd_pipe(
1006
            num_prompts * [prompt],
1007
            num_images_per_prompt=num_images_per_prompt,
1008
            num_inference_steps=2,
1009
            output_type=output_type,
1010
        )
1011

1012
        self.assertEqual(len(outputs.images), 2 * 3)
1013

1014
    def test_stable_diffusion_xl_num_images_per_prompt(self):
1015
        components = self.get_dummy_components()
1016
        gaudi_config = GaudiConfig()
1017

1018
        sd_pipe = GaudiStableDiffusionXLPipeline(
1019
            use_habana=True,
1020
            gaudi_config=gaudi_config,
1021
            **components,
1022
        )
1023
        sd_pipe.set_progress_bar_config(disable=None)
1024

1025
        prompt = "A painting of a squirrel eating a burger"
1026

1027
        # Test num_images_per_prompt=1 (default)
1028
        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
1029

1030
        self.assertEqual(len(images), 1)
1031
        self.assertEqual(images[0].shape, (64, 64, 3))
1032

1033
        # Test num_images_per_prompt=1 (default) for several prompts
1034
        num_prompts = 3
1035
        images = sd_pipe([prompt] * num_prompts, num_inference_steps=2, output_type="np").images
1036

1037
        self.assertEqual(len(images), num_prompts)
1038
        self.assertEqual(images[-1].shape, (64, 64, 3))
1039

1040
        # Test num_images_per_prompt for single prompt
1041
        num_images_per_prompt = 2
1042
        images = sd_pipe(
1043
            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
1044
        ).images
1045

1046
        self.assertEqual(len(images), num_images_per_prompt)
1047
        self.assertEqual(images[-1].shape, (64, 64, 3))
1048

1049
        # Test num_images_per_prompt for several prompts
1050
        num_prompts = 2
1051
        images = sd_pipe(
1052
            [prompt] * num_prompts,
1053
            num_inference_steps=2,
1054
            output_type="np",
1055
            num_images_per_prompt=num_images_per_prompt,
1056
        ).images
1057

1058
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1059
        self.assertEqual(images[-1].shape, (64, 64, 3))
1060

1061
    def test_stable_diffusion_xl_batch_sizes(self):
1062
        components = self.get_dummy_components()
1063
        gaudi_config = GaudiConfig()
1064

1065
        sd_pipe = GaudiStableDiffusionXLPipeline(
1066
            use_habana=True,
1067
            gaudi_config=gaudi_config,
1068
            **components,
1069
        )
1070
        sd_pipe.set_progress_bar_config(disable=None)
1071

1072
        prompt = "A painting of a squirrel eating a burger"
1073

1074
        # Test batch_size > 1 where batch_size is a divider of the total number of generated images
1075
        batch_size = 3
1076
        num_images_per_prompt = batch_size**2
1077
        images = sd_pipe(
1078
            prompt,
1079
            num_inference_steps=2,
1080
            output_type="np",
1081
            batch_size=batch_size,
1082
            num_images_per_prompt=num_images_per_prompt,
1083
        ).images
1084
        self.assertEqual(len(images), num_images_per_prompt)
1085
        self.assertEqual(images[-1].shape, (64, 64, 3))
1086

1087
        # Same test for several prompts
1088
        num_prompts = 3
1089
        images = sd_pipe(
1090
            [prompt] * num_prompts,
1091
            num_inference_steps=2,
1092
            output_type="np",
1093
            batch_size=batch_size,
1094
            num_images_per_prompt=num_images_per_prompt,
1095
        ).images
1096

1097
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1098
        self.assertEqual(images[-1].shape, (64, 64, 3))
1099

1100
        # Test batch_size when it is not a divider of the total number of generated images for a single prompt
1101
        num_images_per_prompt = 7
1102
        images = sd_pipe(
1103
            prompt,
1104
            num_inference_steps=2,
1105
            output_type="np",
1106
            batch_size=batch_size,
1107
            num_images_per_prompt=num_images_per_prompt,
1108
        ).images
1109

1110
        self.assertEqual(len(images), num_images_per_prompt)
1111
        self.assertEqual(images[-1].shape, (64, 64, 3))
1112

1113
        # Same test for several prompts
1114
        num_prompts = 2
1115
        images = sd_pipe(
1116
            [prompt] * num_prompts,
1117
            num_inference_steps=2,
1118
            output_type="np",
1119
            batch_size=batch_size,
1120
            num_images_per_prompt=num_images_per_prompt,
1121
        ).images
1122

1123
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1124
        self.assertEqual(images[-1].shape, (64, 64, 3))
1125

1126
    def test_stable_diffusion_xl_bf16(self):
1127
        """Test that stable diffusion works with bf16"""
1128
        components = self.get_dummy_components()
1129
        gaudi_config = GaudiConfig()
1130

1131
        sd_pipe = GaudiStableDiffusionXLPipeline(
1132
            use_habana=True,
1133
            gaudi_config=gaudi_config,
1134
            **components,
1135
        )
1136
        sd_pipe.set_progress_bar_config(disable=None)
1137

1138
        prompt = "A painting of a squirrel eating a burger"
1139
        generator = torch.Generator(device="cpu").manual_seed(0)
1140
        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images[0]
1141

1142
        self.assertEqual(image.shape, (64, 64, 3))
1143

1144
    def test_stable_diffusion_xl_default(self):
1145
        components = self.get_dummy_components()
1146

1147
        sd_pipe = GaudiStableDiffusionXLPipeline(
1148
            use_habana=True,
1149
            gaudi_config="Habana/stable-diffusion",
1150
            **components,
1151
        )
1152
        sd_pipe.set_progress_bar_config(disable=None)
1153

1154
        prompt = "A painting of a squirrel eating a burger"
1155
        generator = torch.Generator(device="cpu").manual_seed(0)
1156
        images = sd_pipe(
1157
            [prompt] * 2,
1158
            generator=generator,
1159
            num_inference_steps=2,
1160
            output_type="np",
1161
            batch_size=3,
1162
            num_images_per_prompt=5,
1163
        ).images
1164

1165
        self.assertEqual(len(images), 10)
1166
        self.assertEqual(images[-1].shape, (64, 64, 3))
1167

1168
    def test_stable_diffusion_xl_hpu_graphs(self):
1169
        components = self.get_dummy_components()
1170

1171
        sd_pipe = GaudiStableDiffusionXLPipeline(
1172
            use_habana=True,
1173
            use_hpu_graphs=True,
1174
            gaudi_config="Habana/stable-diffusion",
1175
            **components,
1176
        )
1177
        sd_pipe.set_progress_bar_config(disable=None)
1178

1179
        prompt = "A painting of a squirrel eating a burger"
1180
        generator = torch.Generator(device="cpu").manual_seed(0)
1181
        images = sd_pipe(
1182
            [prompt] * 2,
1183
            generator=generator,
1184
            num_inference_steps=2,
1185
            output_type="np",
1186
            batch_size=3,
1187
            num_images_per_prompt=5,
1188
        ).images
1189

1190
        self.assertEqual(len(images), 10)
1191
        self.assertEqual(images[-1].shape, (64, 64, 3))
1192

1193

1194
class GaudiStableDiffusionControlNetPipelineTester(TestCase):
1195
    """
1196
    Tests the StableDiffusionControlNetPipeline for Gaudi.
1197
    """
1198

1199
    def get_dummy_components(self, time_cond_proj_dim=None):
1200
        torch.manual_seed(0)
1201
        unet = UNet2DConditionModel(
1202
            block_out_channels=(4, 8),
1203
            layers_per_block=2,
1204
            sample_size=32,
1205
            time_cond_proj_dim=time_cond_proj_dim,
1206
            in_channels=4,
1207
            out_channels=4,
1208
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
1209
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
1210
            cross_attention_dim=32,
1211
            norm_num_groups=1,
1212
        )
1213

1214
        def init_weights(m):
1215
            if isinstance(m, torch.nn.Conv2d):
1216
                torch.nn.init.normal(m.weight)
1217
                m.bias.data.fill_(1.0)
1218

1219
        torch.manual_seed(0)
1220
        controlnet = ControlNetModel(
1221
            block_out_channels=(4, 8),
1222
            layers_per_block=2,
1223
            in_channels=4,
1224
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
1225
            cross_attention_dim=32,
1226
            conditioning_embedding_out_channels=(16, 32),
1227
            norm_num_groups=1,
1228
        )
1229
        controlnet.controlnet_down_blocks.apply(init_weights)
1230

1231
        scheduler = GaudiDDIMScheduler(
1232
            beta_start=0.00085,
1233
            beta_end=0.012,
1234
            beta_schedule="scaled_linear",
1235
            clip_sample=False,
1236
            set_alpha_to_one=False,
1237
        )
1238
        torch.manual_seed(0)
1239
        vae = AutoencoderKL(
1240
            block_out_channels=[4, 8],
1241
            in_channels=3,
1242
            out_channels=3,
1243
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
1244
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
1245
            latent_channels=4,
1246
            norm_num_groups=2,
1247
        )
1248
        torch.manual_seed(0)
1249
        text_encoder_config = CLIPTextConfig(
1250
            bos_token_id=0,
1251
            eos_token_id=2,
1252
            hidden_size=32,
1253
            intermediate_size=37,
1254
            layer_norm_eps=1e-05,
1255
            num_attention_heads=4,
1256
            num_hidden_layers=5,
1257
            pad_token_id=1,
1258
            vocab_size=1000,
1259
        )
1260
        text_encoder = CLIPTextModel(text_encoder_config)
1261
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1262

1263
        components = {
1264
            "unet": unet,
1265
            "controlnet": controlnet,
1266
            "scheduler": scheduler,
1267
            "vae": vae,
1268
            "text_encoder": text_encoder,
1269
            "tokenizer": tokenizer,
1270
            "safety_checker": None,
1271
            "feature_extractor": None,
1272
        }
1273
        return components
1274

1275
    def get_dummy_inputs(self, device, seed=0):
1276
        generator = torch.Generator(device=device).manual_seed(seed)
1277
        controlnet_embedder_scale_factor = 2
1278
        images = [
1279
            randn_tensor(
1280
                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
1281
                generator=generator,
1282
                device=torch.device(device),
1283
            ),
1284
        ]
1285
        inputs = {
1286
            "prompt": "A painting of a squirrel eating a burger",
1287
            "generator": generator,
1288
            "num_inference_steps": 2,
1289
            "guidance_scale": 6.0,
1290
            "output_type": "np",
1291
            "image": images,
1292
        }
1293
        return inputs
1294

1295
    def test_stable_diffusion_controlnet_num_images_per_prompt(self):
1296
        components = self.get_dummy_components()
1297
        gaudi_config = GaudiConfig()
1298

1299
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1300
            use_habana=True,
1301
            gaudi_config=gaudi_config,
1302
            **components,
1303
        )
1304
        sd_pipe.set_progress_bar_config(disable=None)
1305

1306
        inputs = self.get_dummy_inputs(device="cpu")
1307
        prompt = inputs["prompt"]
1308
        # Test num_images_per_prompt=1 (default)
1309
        images = sd_pipe(**inputs).images
1310

1311
        self.assertEqual(len(images), 1)
1312
        self.assertEqual(images[0].shape, (64, 64, 3))
1313

1314
        # Test num_images_per_prompt=1 (default) for several prompts
1315
        num_prompts = 3
1316
        inputs["prompt"] = [prompt] * num_prompts
1317
        images = sd_pipe(**inputs).images
1318

1319
        self.assertEqual(len(images), num_prompts)
1320
        self.assertEqual(images[-1].shape, (64, 64, 3))
1321

1322
        # Test num_images_per_prompt for single prompt
1323
        num_images_per_prompt = 2
1324
        inputs["prompt"] = prompt
1325
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images
1326

1327
        self.assertEqual(len(images), num_images_per_prompt)
1328
        self.assertEqual(images[-1].shape, (64, 64, 3))
1329

1330
        ## Test num_images_per_prompt for several prompts
1331
        num_prompts = 2
1332
        inputs["prompt"] = [prompt] * num_prompts
1333
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images
1334

1335
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1336
        self.assertEqual(images[-1].shape, (64, 64, 3))
1337

1338
    def test_stable_diffusion_controlnet_batch_sizes(self):
1339
        components = self.get_dummy_components()
1340
        gaudi_config = GaudiConfig()
1341

1342
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1343
            use_habana=True,
1344
            gaudi_config=gaudi_config,
1345
            **components,
1346
        )
1347
        sd_pipe.set_progress_bar_config(disable=None)
1348

1349
        inputs = self.get_dummy_inputs(device="cpu")
1350
        prompt = inputs["prompt"]
1351
        # Test batch_size > 1 where batch_size is a divider of the total number of generated images
1352
        batch_size = 3
1353
        num_images_per_prompt = batch_size**2
1354
        images = sd_pipe(
1355
            batch_size=batch_size,
1356
            num_images_per_prompt=num_images_per_prompt,
1357
            **inputs,
1358
        ).images
1359
        self.assertEqual(len(images), num_images_per_prompt)
1360
        self.assertEqual(images[-1].shape, (64, 64, 3))
1361

1362
        # Same test for several prompts
1363
        num_prompts = 3
1364
        inputs["prompt"] = [prompt] * num_prompts
1365

1366
        images = sd_pipe(
1367
            batch_size=batch_size,
1368
            num_images_per_prompt=num_images_per_prompt,
1369
            **inputs,
1370
        ).images
1371

1372
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1373
        self.assertEqual(images[-1].shape, (64, 64, 3))
1374

1375
        inputs["prompt"] = prompt
1376
        # Test batch_size when it is not a divider of the total number of generated images for a single prompt
1377
        num_images_per_prompt = 7
1378
        images = sd_pipe(
1379
            batch_size=batch_size,
1380
            num_images_per_prompt=num_images_per_prompt,
1381
            **inputs,
1382
        ).images
1383

1384
        self.assertEqual(len(images), num_images_per_prompt)
1385
        self.assertEqual(images[-1].shape, (64, 64, 3))
1386

1387
        # Same test for several prompts
1388
        num_prompts = 2
1389
        inputs["prompt"] = [prompt] * num_prompts
1390
        images = sd_pipe(batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, **inputs).images
1391

1392
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1393
        self.assertEqual(images[-1].shape, (64, 64, 3))
1394

1395
    def test_stable_diffusion_controlnet_bf16(self):
1396
        """Test that stable diffusion works with bf16"""
1397
        components = self.get_dummy_components()
1398
        gaudi_config = GaudiConfig()
1399

1400
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1401
            use_habana=True,
1402
            gaudi_config=gaudi_config,
1403
            **components,
1404
        )
1405
        sd_pipe.set_progress_bar_config(disable=None)
1406

1407
        inputs = self.get_dummy_inputs(device="cpu")
1408
        image = sd_pipe(**inputs).images[0]
1409

1410
        self.assertEqual(image.shape, (64, 64, 3))
1411

1412
    def test_stable_diffusion_controlnet_default(self):
1413
        components = self.get_dummy_components()
1414

1415
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1416
            use_habana=True,
1417
            gaudi_config="Habana/stable-diffusion",
1418
            **components,
1419
        )
1420
        sd_pipe.set_progress_bar_config(disable=None)
1421

1422
        inputs = self.get_dummy_inputs(device="cpu")
1423
        inputs["prompt"] = [inputs["prompt"]] * 2
1424
        images = sd_pipe(
1425
            batch_size=3,
1426
            num_images_per_prompt=5,
1427
            **inputs,
1428
        ).images
1429

1430
        self.assertEqual(len(images), 10)
1431
        self.assertEqual(images[-1].shape, (64, 64, 3))
1432

1433
    def test_stable_diffusion_controlnet_hpu_graphs(self):
1434
        components = self.get_dummy_components()
1435

1436
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1437
            use_habana=True,
1438
            use_hpu_graphs=True,
1439
            gaudi_config="Habana/stable-diffusion",
1440
            **components,
1441
        )
1442
        sd_pipe.set_progress_bar_config(disable=None)
1443

1444
        inputs = self.get_dummy_inputs(device="cpu")
1445
        inputs["prompt"] = [inputs["prompt"]] * 2
1446

1447
        images = sd_pipe(
1448
            batch_size=3,
1449
            num_images_per_prompt=5,
1450
            **inputs,
1451
        ).images
1452

1453
        self.assertEqual(len(images), 10)
1454
        self.assertEqual(images[-1].shape, (64, 64, 3))
1455

1456

1457
class GaudiStableDiffusionMultiControlNetPipelineTester(TestCase):
1458
    """
1459
    Tests the StableDiffusionControlNetPipeline for Gaudi.
1460
    """
1461

1462
    def get_dummy_components(self, time_cond_proj_dim=None):
1463
        torch.manual_seed(0)
1464
        unet = UNet2DConditionModel(
1465
            block_out_channels=(4, 8),
1466
            layers_per_block=2,
1467
            sample_size=32,
1468
            time_cond_proj_dim=time_cond_proj_dim,
1469
            in_channels=4,
1470
            out_channels=4,
1471
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
1472
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
1473
            cross_attention_dim=32,
1474
            norm_num_groups=1,
1475
        )
1476

1477
        def init_weights(m):
1478
            if isinstance(m, torch.nn.Conv2d):
1479
                torch.nn.init.normal(m.weight)
1480
                m.bias.data.fill_(1.0)
1481

1482
        torch.manual_seed(0)
1483
        controlnet1 = ControlNetModel(
1484
            block_out_channels=(4, 8),
1485
            layers_per_block=2,
1486
            in_channels=4,
1487
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
1488
            cross_attention_dim=32,
1489
            conditioning_embedding_out_channels=(16, 32),
1490
            norm_num_groups=1,
1491
        )
1492
        controlnet1.controlnet_down_blocks.apply(init_weights)
1493

1494
        torch.manual_seed(0)
1495
        controlnet2 = ControlNetModel(
1496
            block_out_channels=(4, 8),
1497
            layers_per_block=2,
1498
            in_channels=4,
1499
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
1500
            cross_attention_dim=32,
1501
            conditioning_embedding_out_channels=(16, 32),
1502
            norm_num_groups=1,
1503
        )
1504
        controlnet2.controlnet_down_blocks.apply(init_weights)
1505

1506
        scheduler = GaudiDDIMScheduler(
1507
            beta_start=0.00085,
1508
            beta_end=0.012,
1509
            beta_schedule="scaled_linear",
1510
            clip_sample=False,
1511
            set_alpha_to_one=False,
1512
        )
1513
        torch.manual_seed(0)
1514
        vae = AutoencoderKL(
1515
            block_out_channels=[4, 8],
1516
            in_channels=3,
1517
            out_channels=3,
1518
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
1519
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
1520
            latent_channels=4,
1521
            norm_num_groups=2,
1522
        )
1523
        torch.manual_seed(0)
1524
        text_encoder_config = CLIPTextConfig(
1525
            bos_token_id=0,
1526
            eos_token_id=2,
1527
            hidden_size=32,
1528
            intermediate_size=37,
1529
            layer_norm_eps=1e-05,
1530
            num_attention_heads=4,
1531
            num_hidden_layers=5,
1532
            pad_token_id=1,
1533
            vocab_size=1000,
1534
        )
1535
        text_encoder = CLIPTextModel(text_encoder_config)
1536
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1537

1538
        controlnet = MultiControlNetModel([controlnet1, controlnet2])
1539

1540
        components = {
1541
            "unet": unet,
1542
            "controlnet": controlnet,
1543
            "scheduler": scheduler,
1544
            "vae": vae,
1545
            "text_encoder": text_encoder,
1546
            "tokenizer": tokenizer,
1547
            "safety_checker": None,
1548
            "feature_extractor": None,
1549
        }
1550
        return components
1551

1552
    def get_dummy_inputs(self, device, seed=0):
1553
        generator = torch.Generator(device=device).manual_seed(seed)
1554
        controlnet_embedder_scale_factor = 2
1555
        images = [
1556
            randn_tensor(
1557
                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
1558
                generator=generator,
1559
                device=torch.device(device),
1560
            ),
1561
            randn_tensor(
1562
                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
1563
                generator=generator,
1564
                device=torch.device(device),
1565
            ),
1566
        ]
1567
        inputs = {
1568
            "prompt": "A painting of a squirrel eating a burger",
1569
            "generator": generator,
1570
            "num_inference_steps": 2,
1571
            "guidance_scale": 6.0,
1572
            "output_type": "np",
1573
            "image": images,
1574
        }
1575
        return inputs
1576

1577
    def test_stable_diffusion_multicontrolnet_num_images_per_prompt(self):
1578
        components = self.get_dummy_components()
1579
        gaudi_config = GaudiConfig()
1580

1581
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1582
            use_habana=True,
1583
            gaudi_config=gaudi_config,
1584
            **components,
1585
        )
1586
        sd_pipe.set_progress_bar_config(disable=None)
1587

1588
        inputs = self.get_dummy_inputs(device="cpu")
1589
        prompt = inputs["prompt"]
1590
        # Test num_images_per_prompt=1 (default)
1591
        images = sd_pipe(**inputs).images
1592

1593
        self.assertEqual(len(images), 1)
1594
        self.assertEqual(images[0].shape, (64, 64, 3))
1595

1596
        # Test num_images_per_prompt=1 (default) for several prompts
1597
        num_prompts = 3
1598
        inputs["prompt"] = [prompt] * num_prompts
1599
        images = sd_pipe(**inputs).images
1600

1601
        self.assertEqual(len(images), num_prompts)
1602
        self.assertEqual(images[-1].shape, (64, 64, 3))
1603

1604
        # Test num_images_per_prompt for single prompt
1605
        num_images_per_prompt = 2
1606
        inputs["prompt"] = prompt
1607
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images
1608

1609
        self.assertEqual(len(images), num_images_per_prompt)
1610
        self.assertEqual(images[-1].shape, (64, 64, 3))
1611

1612
        ## Test num_images_per_prompt for several prompts
1613
        num_prompts = 2
1614
        inputs["prompt"] = [prompt] * num_prompts
1615
        images = sd_pipe(num_images_per_prompt=num_images_per_prompt, **inputs).images
1616

1617
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1618
        self.assertEqual(images[-1].shape, (64, 64, 3))
1619

1620
    def test_stable_diffusion_multicontrolnet_batch_sizes(self):
1621
        components = self.get_dummy_components()
1622
        gaudi_config = GaudiConfig()
1623

1624
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1625
            use_habana=True,
1626
            gaudi_config=gaudi_config,
1627
            **components,
1628
        )
1629
        sd_pipe.set_progress_bar_config(disable=None)
1630

1631
        inputs = self.get_dummy_inputs(device="cpu")
1632
        prompt = inputs["prompt"]
1633
        # Test batch_size > 1 where batch_size is a divider of the total number of generated images
1634
        batch_size = 3
1635
        num_images_per_prompt = batch_size**2
1636
        images = sd_pipe(
1637
            batch_size=batch_size,
1638
            num_images_per_prompt=num_images_per_prompt,
1639
            **inputs,
1640
        ).images
1641
        self.assertEqual(len(images), num_images_per_prompt)
1642
        self.assertEqual(images[-1].shape, (64, 64, 3))
1643

1644
        # Same test for several prompts
1645
        num_prompts = 3
1646
        inputs["prompt"] = [prompt] * num_prompts
1647

1648
        images = sd_pipe(
1649
            batch_size=batch_size,
1650
            num_images_per_prompt=num_images_per_prompt,
1651
            **inputs,
1652
        ).images
1653

1654
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1655
        self.assertEqual(images[-1].shape, (64, 64, 3))
1656

1657
        inputs["prompt"] = prompt
1658
        # Test batch_size when it is not a divider of the total number of generated images for a single prompt
1659
        num_images_per_prompt = 7
1660
        images = sd_pipe(
1661
            batch_size=batch_size,
1662
            num_images_per_prompt=num_images_per_prompt,
1663
            **inputs,
1664
        ).images
1665

1666
        self.assertEqual(len(images), num_images_per_prompt)
1667
        self.assertEqual(images[-1].shape, (64, 64, 3))
1668

1669
        # Same test for several prompts
1670
        num_prompts = 2
1671
        inputs["prompt"] = [prompt] * num_prompts
1672
        images = sd_pipe(batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, **inputs).images
1673

1674
        self.assertEqual(len(images), num_prompts * num_images_per_prompt)
1675
        self.assertEqual(images[-1].shape, (64, 64, 3))
1676

1677
    def test_stable_diffusion_multicontrolnet_bf16(self):
1678
        """Test that stable diffusion works with bf16"""
1679
        components = self.get_dummy_components()
1680
        gaudi_config = GaudiConfig()
1681

1682
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1683
            use_habana=True,
1684
            gaudi_config=gaudi_config,
1685
            **components,
1686
        )
1687
        sd_pipe.set_progress_bar_config(disable=None)
1688

1689
        inputs = self.get_dummy_inputs(device="cpu")
1690
        image = sd_pipe(**inputs).images[0]
1691

1692
        self.assertEqual(image.shape, (64, 64, 3))
1693

1694
    def test_stable_diffusion_multicontrolnet_default(self):
1695
        components = self.get_dummy_components()
1696

1697
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1698
            use_habana=True,
1699
            gaudi_config="Habana/stable-diffusion",
1700
            **components,
1701
        )
1702
        sd_pipe.set_progress_bar_config(disable=None)
1703

1704
        inputs = self.get_dummy_inputs(device="cpu")
1705
        inputs["prompt"] = [inputs["prompt"]] * 2
1706
        images = sd_pipe(
1707
            batch_size=3,
1708
            num_images_per_prompt=5,
1709
            **inputs,
1710
        ).images
1711

1712
        self.assertEqual(len(images), 10)
1713
        self.assertEqual(images[-1].shape, (64, 64, 3))
1714

1715
    def test_stable_diffusion_multicontrolnet_hpu_graphs(self):
1716
        components = self.get_dummy_components()
1717

1718
        sd_pipe = GaudiStableDiffusionControlNetPipeline(
1719
            use_habana=True,
1720
            use_hpu_graphs=True,
1721
            gaudi_config="Habana/stable-diffusion",
1722
            **components,
1723
        )
1724
        sd_pipe.set_progress_bar_config(disable=None)
1725

1726
        inputs = self.get_dummy_inputs(device="cpu")
1727
        inputs["prompt"] = [inputs["prompt"]] * 2
1728

1729
        images = sd_pipe(
1730
            batch_size=3,
1731
            num_images_per_prompt=5,
1732
            **inputs,
1733
        ).images
1734

1735
        self.assertEqual(len(images), 10)
1736
        self.assertEqual(images[-1].shape, (64, 64, 3))
1737

1738

1739
class TrainTextToImage(TestCase):
1740
    """
1741
    Tests the Stable Diffusion text_to_image Training for Gaudi.
1742
    """
1743

1744
    def test_train_text_to_image_script(self):
1745
        path_to_script = (
1746
            Path(os.path.dirname(__file__)).parent
1747
            / "examples"
1748
            / "stable-diffusion"
1749
            / "training"
1750
            / "train_text_to_image_sdxl.py"
1751
        )
1752

1753
        cmd_line = f"""ls {path_to_script}""".split()
1754

1755
        # check find existence
1756
        p = subprocess.Popen(cmd_line)
1757
        return_code = p.wait()
1758

1759
        # Ensure the run finished without any issue
1760
        self.assertEqual(return_code, 0)
1761

1762
    @slow
1763
    def test_train_text_to_image_sdxl(self):
1764
        with tempfile.TemporaryDirectory() as tmpdir:
1765
            path_to_script = (
1766
                Path(os.path.dirname(__file__)).parent
1767
                / "examples"
1768
                / "stable-diffusion"
1769
                / "training"
1770
                / "train_text_to_image_sdxl.py"
1771
            )
1772

1773
            cmd_line = f"""
1774
                 python3
1775
                 {path_to_script}
1776
                 --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0
1777
                 --pretrained_vae_model_name_or_path stabilityai/sdxl-vae
1778
                 --dataset_name lambdalabs/pokemon-blip-captions
1779
                 --resolution 64
1780
                 --center_crop
1781
                 --random_flip
1782
                 --proportion_empty_prompts=0.2
1783
                 --train_batch_size 1
1784
                 --gradient_accumulation_steps 4
1785
                 --learning_rate 1e-05
1786
                 --max_grad_norm 1
1787
                 --lr_scheduler constant
1788
                 --lr_warmup_steps 0
1789
                 --gaudi_config_name Habana/stable-diffusion
1790
                 --throughput_warmup_steps 3
1791
                 --use_hpu_graphs
1792
                 --bf16
1793
                 --max_train_steps 2
1794
                 --output_dir {tmpdir}
1795
                """.split()
1796

1797
            # Run train_text_to_image_sdxl.y
1798
            p = subprocess.Popen(cmd_line)
1799
            return_code = p.wait()
1800

1801
            # Ensure the run finished without any issue
1802
            self.assertEqual(return_code, 0)
1803

1804
            # save_pretrained smoke test
1805
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
1806
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
1807

1808
    @slow
1809
    def test_train_text_to_image_sdxl_lora(self):
1810
        with tempfile.TemporaryDirectory() as tmpdir:
1811
            path_to_script = (
1812
                Path(os.path.dirname(__file__)).parent
1813
                / "examples"
1814
                / "stable-diffusion"
1815
                / "training"
1816
                / "train_text_to_image_sdxl.py"
1817
            )
1818

1819
            cmd_line = f"""
1820
                 python3
1821
                 {path_to_script}
1822
                 --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0
1823
                 --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix
1824
                 --dataset_name=lambdalabs/pokemon-blip-captions
1825
                 --caption_column=text
1826
                 --resolution=64
1827
                 --random_flip
1828
                 --train_batch_size=1
1829
                 --learning_rate=1e-04
1830
                 --lr_scheduler=constant
1831
                 --lr_warmup_steps=0
1832
                 --seed=42
1833
                 --finetuning_method=lora
1834
                 --gaudi_config_name=Habana/stable-diffusion
1835
                 --throughput_warmup_steps=3
1836
                 --use_hpu_graphs
1837
                 --bf16
1838
                 --max_train_steps 2
1839
                 --output_dir {tmpdir}
1840
                """.split()
1841

1842
            # Run train_text_to_image_lora.py
1843
            p = subprocess.Popen(cmd_line)
1844
            return_code = p.wait()
1845

1846
            # Ensure the run finished without any issue
1847
            self.assertEqual(return_code, 0)
1848

1849
            # save_pretrained smoke test
1850
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
1851

1852

1853
class TrainControlNet(TestCase):
1854
    """
1855
    Tests the train_controlnet.py script for Gaudi.
1856
    """
1857

1858
    def test_train_controlnet_script(self):
1859
        path_to_script = (
1860
            Path(os.path.dirname(__file__)).parent
1861
            / "examples"
1862
            / "stable-diffusion"
1863
            / "training"
1864
            / "train_controlnet.py"
1865
        )
1866

1867
        cmd_line = f"""ls {path_to_script}""".split()
1868

1869
        # check find existence
1870
        p = subprocess.Popen(cmd_line)
1871
        return_code = p.wait()
1872

1873
        # Ensure the run finished without any issue
1874
        self.assertEqual(return_code, 0)
1875

1876
    @slow
1877
    def test_train_controlnet(self):
1878
        with tempfile.TemporaryDirectory() as tmpdir:
1879
            path_to_script = (
1880
                Path(os.path.dirname(__file__)).parent
1881
                / "examples"
1882
                / "stable-diffusion"
1883
                / "training"
1884
                / "train_controlnet.py"
1885
            )
1886

1887
            download_files(
1888
                [
1889
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png",
1890
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png",
1891
                ],
1892
                path=tmpdir,
1893
            )
1894

1895
            cmd_line = f"""
1896
                    python3
1897
                    {path_to_script.parent.parent.parent / 'gaudi_spawn.py'}
1898
                    --use_mpi
1899
                    --world_size 8
1900
                    {path_to_script}
1901
                    --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5
1902
                    --dataset_name fusing/fill50k
1903
                    --resolution 512
1904
                    --train_batch_size 4
1905
                    --learning_rate 1e-05
1906
                    --validation_steps 1000
1907
                    --validation_image "{tmpdir}/conditioning_image_1.png" "{tmpdir}/conditioning_image_2.png"
1908
                    --validation_prompt "red circle with blue background" "cyan circle with brown floral background"
1909
                    --checkpointing_steps 1000
1910
                    --throughput_warmup_steps 3
1911
                    --use_hpu_graphs
1912
                    --bf16
1913
                    --num_train_epochs 1
1914
                    --output_dir {tmpdir}
1915
                """.split()
1916

1917
            # Run train_controlnet.y
1918
            p = subprocess.Popen(cmd_line)
1919
            return_code = p.wait()
1920

1921
            # Ensure the run finished without any issue
1922
            self.assertEqual(return_code, 0)
1923

1924
            # Assess throughput
1925
            with open(Path(tmpdir) / "speed_metrics.json") as fp:
1926
                results = json.load(fp)
1927
            self.assertGreaterEqual(results["train_samples_per_second"], 0.95 * CONTROLNET_THROUGHPUT)
1928
            self.assertLessEqual(results["train_runtime"], 1.05 * CONTROLNET_RUNTIME)
1929

1930
            # Assess generated image
1931
            controlnet = ControlNetModel.from_pretrained(tmpdir, torch_dtype=torch.bfloat16)
1932
            pipe = GaudiStableDiffusionControlNetPipeline.from_pretrained(
1933
                "runwayml/stable-diffusion-v1-5",
1934
                controlnet=controlnet,
1935
                torch_dtype=torch.bfloat16,
1936
                use_habana=True,
1937
                use_hpu_graphs=True,
1938
                gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
1939
            )
1940
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
1941

1942
            control_image = load_image(f"{tmpdir}/conditioning_image_1.png")
1943
            prompt = "pale golden rod circle with old lace background"
1944

1945
            generator = set_seed(27)
1946
            image = pipe(
1947
                prompt, num_inference_steps=20, generator=generator, image=control_image, output_type="np"
1948
            ).images[0]
1949

1950
            self.assertEqual(image.shape, (512, 512, 3))
1951

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.