stable-diffusion-webui
298 lines · 13.5 KB
1import torch2from modules import prompt_parser, devices, sd_samplers_common3
4from modules.shared import opts, state5import modules.shared as shared6from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback7from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback8from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback9
10
def catenate_conds(conds):
    """Concatenate a list of conditionings along the batch dimension.

    Accepts either a list of tensors or a list of dicts of tensors
    (dict-shaped conds are concatenated key by key).
    """
    if isinstance(conds[0], dict):
        return {key: torch.cat([cond[key] for cond in conds]) for key in conds[0]}

    return torch.cat(conds)
def subscript_cond(cond, a, b):
    """Slice rows [a:b) out of a conditioning, which may be a tensor or a dict of tensors."""
    if isinstance(cond, dict):
        return {key: value[a:b] for key, value in cond.items()}

    return cond[a:b]
def pad_cond(tensor, repeats, empty):
    """Extend a conditioning by `repeats` copies of the empty-prompt embedding along dim 1.

    For dict-shaped conds only the 'crossattn' entry is padded (in place, and the
    same dict is returned); plain tensors are padded functionally.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None
        self.mask = None  # inpainting blend mask; None when no masking is active
        self.nmask = None  # complement of self.mask, used in apply_blend()
        self.init_latent = None  # original latents blended back in when masking
        self.steps = None
        """number of steps as specified by user in UI"""

        self.total_steps = None
        """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""

        self.step = 0  # index of the current denoiser call; incremented at the end of forward()
        self.image_cfg_scale = None  # image CFG scale for InstructPix2Pix-style edit models
        self.padded_cond_uncond = False  # set by pad_cond_uncond() when padding happened this step
        self.padded_cond_uncond_v0 = False  # set by pad_cond_uncond_v0() when padding happened this step
        self.sampler = sampler
        self.model_wrap = None  # NOTE(review): duplicate of the assignment above; harmless but redundant
        self.p = None  # the processing object; assigned externally before sampling — TODO confirm

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        # Subclasses must override this to return the wrapped denoising model.
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """Combine per-prompt denoised outputs into the final CFG result.

        x_out holds the cond-denoised images followed by the uncond-denoised
        images in its last uncond.shape[0] entries; each weighted cond result
        pushes the output away from the uncond result by weight * cond_scale.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """CFG combination for InstructPix2Pix-style edit models.

        x_out is expected to contain three equal chunks: cond, image-cond and
        uncond results; text guidance uses cond_scale, image guidance uses
        self.image_cfg_scale.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        # Default: the model output is already treated as the predicted x0;
        # subclasses may override to convert from other parameterizations.
        return x_out

    def update_inner_model(self):
        """Drop the cached wrapped model and refresh conds from the processing object.

        Called when the underlying sd model changes (e.g. refiner switch) so the
        sampler's extra args match the new model.
        """
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """Pad the shorter of cond/uncond with whole empty-prompt chunks so their token
        dimensions match; sets self.padded_cond_uncond when any padding was applied."""
        empty = shared.sd_model.cond_stage_model_empty_prompt
        # how many whole empty-prompt blocks uncond is ahead of (or behind) cond
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.

        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
        If 'uncond' is a tensor, it is padded directly.

        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
        is repeated to match the number of columns in 'cond'.

        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
        to match the number of columns in 'cond'.

        Args:
            cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """

        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            # repeat the last token embedding to reach cond's length
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            # truncate to cond's length
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one classifier-free-guidance denoising step.

        Builds a combined batch of all weighted cond prompts plus the uncond
        prompt, runs the inner model (in one call or in sub-batches), then
        combines the outputs with combine_denoised*() and applies mask
        blending, previews and script callbacks.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        # if the refiner took over this step, re-read conds from the sampler args
        if sd_samplers_common.apply_refiner(self):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        # resolve schedules into concrete tensors for the current step
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            # let scripts override the blended result
            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]  # number of AND-prompts per image

        # choose how to package crossattn + image conditioning for the inner model
        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        # assemble the combined batch: each image repeated once per AND-prompt,
        # then the uncond copies (and for edit models an extra image-cond copy)
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        # scripts may replace any of these tensors via the callback
        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]  # drop the uncond entries from the batch
            sigma_in = sigma_in[:-batch_size]

        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
        elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond(tensor, uncond)

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            # cond and uncond share a token length: run them as one concatenated batch
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            if shared.opts.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                # split the call into batch_size-sized chunks to save memory
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # token lengths differ: run cond chunks and uncond as separate model calls
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])

                if not is_edit_model:
                    c_crossattn = subscript_cond(tensor, a, b)
                else:
                    # NOTE(review): torch.cat's second positional argument is `dim`;
                    # passing `uncond` here looks wrong and would raise if this
                    # branch were ever reached — verify against upstream.
                    c_crossattn = torch.cat([tensor[a:b]], uncond)

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

            if not skip_uncond:
                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))

        # index of the first (primary) cond result for each image in the batch
        denoised_image_indexes = [x[0][0] for x in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            # fake uncond equals cond, so any scale would be a no-op; use 1.0
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        # pick which latent the live preview shows
        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised