stable-diffusion-webui

import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)

sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
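# Both UNet implementations get their forward wrapped here: the ldm one serves
# SD1.x/SD2.x checkpoints and the sgm one serves SDXL. The wrapper produced by
# sd_unet.create_unet_forward is intended to let a replacement UNet registered
# through the sd_unet module take over at call time, while the unpatched
# callables returned by patches.patch are kept so hijack() can expose the right
# original_forward for whichever model is loaded.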
def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)
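# A rough sketch of how an entry ends up in `optimizers` (illustration only, not
# code from this file): it assumes the SdOptimization interface used below
# (is_available/priority/title/apply/undo/name/cmd_opt) and the
# script_callbacks.on_list_optimizers registration hook.
#
#     from modules import script_callbacks, sd_hijack_optimizations
#
#     class MyAttentionOpt(sd_hijack_optimizations.SdOptimization):
#         name = "my-attention"
#         priority = 10
#
#         def is_available(self):
#             return True
#
#         def apply(self):
#             ...  # e.g. replace CrossAttention.forward with an optimized version
#
#     def on_list(res):
#         res.append(MyAttentionOpt())
#
#     script_callbacks.on_list_optimizers(on_list)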
def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
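# Selection precedence above: an explicit "None" disables optimization; "Automatic"
# honours --disable-opt-split-attention and otherwise prefers an optimizer whose
# command-line flag (cmd_opt) was passed, falling back to the highest-priority
# entry; any other value is matched against optimizer titles, again falling back
# to the highest-priority entry if nothing matches.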
def undo_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""

    pass


def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the loss normally, but without reducing it to a mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as a mean if requested
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily attach the weights somewhere accessible during loss calculation
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one; otherwise 'forward' would have to be reimplemented completely.
        # Keep the original 'get_loss', but don't overwrite a previously saved _old_get_loss if it's already set.
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete the temporary weights if they were attached
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we saved an old loss function, restore the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add a new method 'weighted_forward' that can be called to compute a weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
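# Hedged usage sketch (not code from this file): training code that has a weight
# tensor `w` broadcastable to the unreduced loss could use the patched method
# like this, assuming `x` and `c` are the usual LatentDiffusion.forward inputs
# and that forward returns (loss, loss_dict):
#
#     apply_weighted_forward(sd_model)
#     loss = sd_model.weighted_forward(x, c, w)[0]   # weighted training loss
#     ...
#     undo_weighted_forward(sd_model)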
class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def convert_sdxl_to_ssd(self, m):
        """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""

        delattr(m.model.diffusion_model.middle_block, '1')
        delattr(m.model.diffusion_model.middle_block, '2')
        for i in ['9', '8', '7', '6', '5', '4']:
            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
        devices.torch_gc()
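    # The submodules above are removed with delattr by their string names because
    # they are numbered children of sequential containers; pruning the deeper
    # middle-block stages and most of the inner transformer blocks is what turns
    # the full SDXL UNet into the lighter SSD-1B layout targeted here.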
    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__
                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        import modules.models.diffusion.ddpm_edit

        if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
            sd_unet.original_forward = sgm_original_forward
        else:
            sd_unet.original_forward = None
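        # Which pre-patch UNet forward sd_unet should fall back to depends on the
        # codebase the loaded model comes from: SD1.x/SD2.x and the instruct-pix2pix
        # variant (ddpm_edit) use the ldm UNet, SDXL's DiffusionEngine uses the sgm
        # UNet, and anything else gets no fallback.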
    def undo_hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
                    embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped
                if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
                    embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped

            if hasattr(m, 'cond_stage_model'):
                delattr(m, 'cond_stage_model')

        elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'
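    # apply_circular is what drives the seamless/"tiling" image option: with
    # padding_mode='circular' every Conv2d in the hijacked model wraps around at
    # the edges so the output tiles without seams, and 'zeros' restores the
    # normal behaviour.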
    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
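# How the splice above works, roughly: for each prompt, `fixes` is a list of
# (offset, embedding) pairs prepared by the hijacked text encoder; the rows of a
# textual-inversion vector (or its 'clip_l'/'clip_g' entry when the embedding
# carries one per SDXL text encoder) overwrite the token embeddings starting one
# position after `offset`, clipped so the sequence length never changes.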
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
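# Note that this monkeypatches torch.nn.Conv2d.__init__ process-wide: every
# Conv2d constructed after the call gets circular padding, nothing in this
# module restores the original constructor, and a caller that passes its own
# padding_mode keyword would hit a duplicate-argument TypeError.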
model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)
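# The override above replaces DDIMSampler/PLMSSampler.register_buffer (see the
# assignments below) so sampler buffers land on the active device; on Apple's MPS
# backend they are additionally cast to float32, since MPS has no float64
# support, which is presumably the "Mac OS" bug the docstring refers to.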
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer