stable-diffusion-webui
391 строка · 16.7 Кб
1import torch2from torch.nn.functional import silu3from types import MethodType4
5from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches6from modules.hypernetworks import hypernetwork7from modules.shared import cmd_opts8from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m189
10import ldm.modules.attention11import ldm.modules.diffusionmodules.model12import ldm.modules.diffusionmodules.openaimodel13import ldm.models.diffusion.ddpm14import ldm.models.diffusion.ddim15import ldm.models.diffusion.plms16import ldm.modules.encoders.modules17
18import sgm.modules.attention19import sgm.modules.diffusionmodules.model20import sgm.modules.diffusionmodules.openaimodel21import sgm.modules.encoders.modules22
# Saved references to the stock ldm implementations so optimizations can be undone later
# (see undo_optimizations / apply_optimizations).
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

# Cross-attention optimizations registered via script callbacks; filled by list_optimizers().
optimizers = []
# The optimization currently applied, or None.
current_optimizer: sd_hijack_optimizations.SdOptimization = None

# Wrap the ldm/sgm UNet forward with sd_unet's dispatching wrapper.
# NOTE(review): judging by the names and their later use as sd_unet.original_forward,
# patches.patch replaces the attribute and returns the original — confirm in modules/patches.py.
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)

sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
47
def list_optimizers():
    """Rebuild the module-level ``optimizers`` list from the script-callback
    registry, keeping only optimizations that report themselves available,
    ordered by descending priority."""
    reported = script_callbacks.list_optimizers_callback()

    usable = [opt for opt in reported if opt.is_available()]
    usable.sort(key=lambda opt: opt.priority, reverse=True)

    # replace contents in place: other code holds a reference to this list
    optimizers[:] = usable
58
def apply_optimizations(option=None):
    """Undo any active optimization, then apply the cross-attention optimization
    selected by ``option`` (or by the ``cross_attention_optimization`` setting).

    Returns the name of the applied optimization, or '' when none was applied.
    """
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    # swap in torch's silu for the nonlinearity and sd_hijack_unet's th helper,
    # for both the ldm and sgm codebases
    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    # revert a previously applied optimization before applying a new one
    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        # prefer an optimizer explicitly requested via its command-line flag,
        # otherwise the highest-priority one
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        # legacy flag that disables the automatic choice entirely
        matching_optimizer = None
    elif matching_optimizer is None:
        # unknown selection: fall back to the highest-priority optimizer
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
102
def undo_optimizations():
    """Restore the stock nonlinearity, CrossAttention.forward and
    AttnBlock.forward on both the ldm and sgm codebases."""
    for attention_module, model_module in (
        (ldm.modules.attention, ldm.modules.diffusionmodules.model),
        (sgm.modules.attention, sgm.modules.diffusionmodules.model),
    ):
        model_module.nonlinearity = diffusionmodules_model_nonlinearity
        attention_module.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        model_module.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
112
def fix_checkpoint():
    """Intentional no-op, kept for backward compatibility: gradient checkpoints
    are now added and removed in the embedding/hypernetwork code, since torch
    warns when checkpoints are added while not training."""
119
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute the model's loss with an optional per-element weight.

    Delegates to the saved original loss function (``_old_get_loss``) with
    ``mean=False`` so individual elements survive, scales by
    ``sd_model._custom_loss_weight`` when one has been stashed on the model,
    and only then reduces to a mean if ``mean`` is True.
    """
    elementwise = sd_model._old_get_loss(pred, target, mean=False)

    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        elementwise *= weight

    if mean:
        return elementwise.mean()
    return elementwise
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run the model's forward pass while weighting the training loss by ``w``.

    Temporarily stashes ``w`` on the model and swaps ``get_loss`` for the
    weight-aware ``weighted_loss``; both changes are reverted afterwards even
    if the forward pass raises.
    """
    try:
        # stash the weights where weighted_loss can find them during loss calc
        sd_model._custom_loss_weight = w

        # keep a handle to the original get_loss, but never clobber a handle
        # already saved by an earlier call that has not cleaned up yet
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # standard forward, now routed through the patched get_loss
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        # drop the temporary weight attribute, if still present
        if hasattr(sd_model, '_custom_loss_weight'):
            del sd_model._custom_loss_weight

        # restore the original get_loss and discard the saved handle
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss
def apply_weighted_forward(sd_model):
    """Attach a bound ``weighted_forward`` method to ``sd_model`` so callers
    can compute a loss weighted per sample."""
    setattr(sd_model, 'weighted_forward', MethodType(weighted_forward, sd_model))
def undo_weighted_forward(sd_model):
    """Remove the ``weighted_forward`` helper installed by
    apply_weighted_forward; a no-op when it was never applied."""
    try:
        delattr(sd_model, 'weighted_forward')
    except AttributeError:
        # nothing to remove
        pass
167
class StableDiffusionModelHijack:
    """Replaces ("hijacks") parts of a loaded Stable Diffusion model — mainly
    its text encoder(s) — with webui wrappers that add prompt-editing and
    textual-inversion embedding support, and manages attention optimizations."""

    fixes = None  # per-batch textual-inversion fixes, read by EmbeddingsWithFixes.forward
    layers = None  # flattened list of every submodule of the hijacked model (set in hijack())
    circular_enabled = False  # whether Conv2d layers are currently set to circular padding
    clip = None  # the wrapped cond_stage_model after hijack(); None when nothing is hijacked
    optimization_method = None  # name of the cross-attention optimization currently applied

    def __init__(self):
        # local import — presumably avoids a circular import at module load time; confirm
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        """Apply the selected cross-attention optimization; on failure, report
        the error and revert to the unoptimized implementations."""
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def convert_sdxl_to_ssd(self, m):
        """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""

        # SSD-1B is SDXL with specific transformer/attention blocks removed;
        # the blocks are deleted by string index from their containers
        delattr(m.model.diffusion_model.middle_block, '1')
        delattr(m.model.diffusion_model.middle_block, '2')
        for i in ['9', '8', '7', '6', '5', '4']:
            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
        devices.torch_gc()

    def hijack(self, m):
        """Wrap the model's text encoder(s) with the custom-words/embeddings
        wrappers, apply optimizations and record the model's layers."""
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            # SDXL-style model: several embedders live on m.conditioner
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__
                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    # this embedder's textual-inversion vectors are stored under the 'clip_g' key
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            # instruct-pix2pix style model
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            # depth-first list of el and all of its (recursive) children
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        # local import — presumably avoids a circular import; confirm
        import modules.models.diffusion.ddpm_edit

        # remember the unpatched UNet forward for this model family so sd_unet can restore it
        if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
            sd_unet.original_forward = sgm_original_forward
        else:
            sd_unet.original_forward = None

    def undo_hijack(self, m):
        """Reverse everything hijack() did: unwrap the text encoder(s) and
        restore the stock implementations."""
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
                    embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped
                if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
                    embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped

            # hijack() may have aliased cond_stage_model to the conditioner; drop the alias
            if hasattr(m, 'cond_stage_model'):
                delattr(m, 'cond_stage_model')

        elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Switch every Conv2d layer of the hijacked model between circular and
        zero padding (circular padding is used for tileable image generation)."""
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        # reset per-generation bookkeeping
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        """Return (token_count, target_prompt_length) for ``text``, or
        ("-", "-") when no text encoder has been hijacked yet."""
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        """Convenience: undo and immediately re-apply the hijack on ``m``."""
        self.undo_hijack(m)
        self.hijack(m)
336
class EmbeddingsWithFixes(torch.nn.Module):
    """Wrapper around a token-embedding module that splices textual-inversion
    vectors into the embedded prompt at recorded offsets.

    ``embeddings`` is the object (the model hijack) whose ``fixes`` attribute
    carries, per batch row, a list of ``(offset, embedding)`` pairs to insert.
    """

    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        # take ownership of the pending fixes so they are applied at most once
        pending = self.embeddings.fixes
        self.embeddings.fixes = None

        embedded = self.wrapped(input_ids)

        if pending is None or len(pending) == 0 or max(len(row) for row in pending) == 0:
            return embedded

        patched_rows = []
        for row_fixes, row in zip(pending, embedded):
            for offset, embedding in row_fixes:
                if isinstance(embedding.vec, dict):
                    vec = embedding.vec[self.textual_inversion_key]
                else:
                    vec = embedding.vec
                vec = devices.cond_cast_unet(vec)
                # clip the inserted vector so it never runs past the end of the row
                usable = min(row.shape[0] - offset - 1, vec.shape[0])
                row = torch.cat([row[:offset + 1], vec[:usable], row[offset + 1 + usable:]])

            patched_rows.append(row)

        return torch.stack(patched_rows)
365
def add_circular_option_to_conv_2d():
    """Monkey-patch ``torch.nn.Conv2d`` so that every instance created after
    this call uses circular padding (used for tileable image generation).

    Fix: the previous wrapper always injected ``padding_mode='circular'`` as an
    additional keyword, so any caller that passed ``padding_mode`` explicitly
    crashed with a "multiple values for keyword argument" TypeError. The key is
    now overwritten in ``kwargs`` instead, forcing circular padding in every case.
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # overwrite rather than duplicate the keyword argument
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
374
# Shared module-level hijack instance used across the webui.
model_hijack = StableDiffusionModelHijack()
377
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS: tensors are moved to the configured
    device before being stored, casting to float32 on MPS (presumably because
    MPS does not support the original dtype — confirm).
    """

    if type(attr) == torch.Tensor and attr.device != devices.device:
        # on MPS force float32; elsewhere keep the original dtype (dtype=None)
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
389
# Replace both ldm samplers' register_buffer with the MPS-safe version above.
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer