# stable-diffusion-webui
# 344 lines · 12.6 KB
1import inspect2from collections import namedtuple3import numpy as np4import torch5from PIL import Image6from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models7from modules.shared import opts, state8import k_diffusion.sampling9
10
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])


class SamplerData(SamplerDataTuple):
    """Record describing one sampler: display name, constructor, aliases and option flags."""

    def total_steps(self, steps):
        """Return the effective number of steps for the requested step count.

        Samplers flagged with the "second_order" option evaluate twice per
        requested step, so their step count is doubled.
        """
        needs_doubling = self.options.get("second_order", False)
        return steps * 2 if needs_doubling else steps
21
def setup_img2img_steps(p, steps=None):
    """Compute (steps, t_enc) for img2img sampling.

    With opts.img2img_fix_steps (or an explicit steps override), the total
    step count is inflated so that the requested number of steps actually
    runs after denoising-strength truncation; otherwise t_enc is derived
    from the denoising strength directly.
    """
    strength = min(p.denoising_strength, 0.999)  # clamp to avoid division issues at 1.0

    if opts.img2img_fix_steps or steps is not None:
        wanted = steps or p.steps
        total = int(wanted / strength) if p.denoising_strength > 0 else 0
        enc = wanted - 1
    else:
        total = p.steps
        enc = int(strength * total)

    return total, enc
33
# Maps the human-readable VAE/preview method names (as stored in options)
# to the integer codes consumed by samples_to_images_tensor.
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
36
def samples_to_images_tensor(sample, approximation=None, model=None):
    """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
    if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
        # No explicit method (or a fast-interrupt preview): use the configured progress type.
        approximation = approximation_indexes.get(opts.show_progress_type, 0)

    from modules import lowvram
    if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
        # Full decode is too heavy under lowvram; fall back to the NN approximation.
        approximation = 1

    if approximation == 2:
        return sd_vae_approx.cheap_approximation(sample)

    if approximation == 1:
        return sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()

    if approximation == 3:
        decoded = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
        return decoded * 2 - 1  # TAESD outputs [0, 1]; rescale to [-1, 1]

    # approximation == 0: full decode through the model's first stage (VAE).
    if model is None:
        model = shared.sd_model
    with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
        return model.decode_first_stage(sample.to(model.first_stage_model.dtype))
62
def single_sample_to_image(sample, approximation=None):
    """Decode a single latent sample into a PIL Image."""
    tensor = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0]
    # Map [-1, 1] to [0, 1] and clamp out-of-range values.
    tensor = torch.clamp(tensor * 0.5 + 0.5, min=0.0, max=1.0)
    # CHW float tensor -> HWC uint8 array.
    array = 255. * np.moveaxis(tensor.cpu().numpy(), 0, 2)
    return Image.fromarray(array.astype(np.uint8))
72
def decode_first_stage(model, x):
    """Decode latents x using the VAE decode method selected in options."""
    method_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
    return samples_to_images_tensor(x.to(devices.dtype_vae), method_index, model)
78
def sample_to_image(samples, index=0, approximation=None):
    """Decode the index-th latent of a batch into a PIL Image."""
    chosen = samples[index]
    return single_sample_to_image(chosen, approximation)
82
def samples_to_image_grid(samples, approximation=None):
    """Decode every latent in the batch and assemble the results into one grid image."""
    decoded = []
    for latent in samples:
        decoded.append(single_sample_to_image(latent, approximation))
    return images.image_grid(decoded)
86
def images_tensor_to_samples(image, approximation=None, model=None):
    '''image[0, 1] -> latent'''
    if approximation is None:
        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

    if approximation == 3:
        # TAESD encoder works directly on the [0, 1] image tensor.
        return sd_vae_taesd.encoder_model()(image.to(devices.device, devices.dtype))

    if model is None:
        model = shared.sd_model
    model.first_stage_model.to(devices.dtype_vae)

    # VAE expects inputs in [-1, 1].
    scaled = image.to(shared.device, dtype=devices.dtype_vae) * 2 - 1

    if len(scaled) > 1:
        # NOTE(review): images are encoded one at a time and stacked —
        # presumably to keep results identical to single-image encoding.
        encoded = [
            model.get_first_stage_encoding(model.encode_first_stage(img.unsqueeze(0)))[0]
            for img in scaled
        ]
        return torch.stack(encoded)

    return model.get_first_stage_encoding(model.encode_first_stage(scaled))
114
def store_latent(decoded):
    """Record the latest latent and, when a live preview is due, publish a decoded preview image."""
    state.current_latent = decoded

    preview_due = (
        opts.live_previews_enable
        and opts.show_progress_every_n_steps > 0
        and shared.state.sampling_step % opts.show_progress_every_n_steps == 0
    )
    if preview_due and not shared.parallel_processing_allowed:
        shared.state.assign_current_image(sample_to_image(decoded))
122
def is_sampler_using_eta_noise_seed_delta(p):
    """returns whether sampler from config will use eta noise seed delta for image creation"""

    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)

    # Resolve eta: explicit on p, then from the sampler instance, then the
    # config default (0 or 1.0 depending on default_eta_is_0).
    eta = p.eta

    if eta is None and p.sampler is not None:
        eta = p.sampler.eta

    if eta is None and sampler_config is not None:
        eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0

    if eta == 0:
        return False

    # Fix: previously this fell through to sampler_config.options and raised
    # AttributeError when the sampler name was unknown (config is None) while
    # eta was non-zero. An unknown sampler cannot be known to use ENSD.
    if sampler_config is None:
        return False

    return sampler_config.options.get("uses_ensd", False)
141
class InterruptedException(BaseException):
    """Raised to abort sampling early (e.g. when a sampler's stop_at step is reached).

    Sampler.launch_sampling catches this and returns the last stored latent.
    NOTE(review): inherits from BaseException rather than Exception —
    presumably so broad `except Exception` handlers do not swallow it; confirm.
    """
    pass
145
def replace_torchsde_browinan():
    """Monkey-patch torchsde's Brownian-interval noise source to draw seeded noise via devices.randn_local."""
    import torchsde._brownian.brownian_interval

    def seeded_randn(size, dtype, device, seed):
        noise = devices.randn_local(seed, size)
        return noise.to(device=device, dtype=dtype)

    torchsde._brownian.brownian_interval._randn = seeded_randn


replace_torchsde_browinan()
157
def apply_refiner(cfg_denoiser):
    """Switch to the refiner checkpoint once the configured point in sampling is reached.

    Returns True when the refiner model weights were loaded, False when no
    switch should happen: too early in sampling, no refiner configured, the
    refiner is already the active checkpoint, or the hires-fix refiner-pass
    setting excludes the current pass.
    """
    p = cfg_denoiser.p
    progress = cfg_denoiser.step / cfg_denoiser.total_steps
    switch_at = p.refiner_switch_at
    checkpoint = p.refiner_checkpoint_info

    if switch_at is not None and progress < switch_at:
        return False

    if checkpoint is None or shared.sd_model.sd_checkpoint_info == checkpoint:
        return False

    if getattr(p, "enable_hr", False):
        second_pass = p.is_hr_pass

        if opts.hires_fix_refiner_pass == "first pass" and second_pass:
            return False

        if opts.hires_fix_refiner_pass == "second pass" and not second_pass:
            return False

        if opts.hires_fix_refiner_pass != "second pass":
            p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass

    # Record the switch in infotext before reloading weights.
    p.extra_generation_params['Refiner'] = checkpoint.short_title
    p.extra_generation_params['Refiner switch at'] = switch_at

    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=checkpoint)

    devices.torch_gc()
    p.setup_conds()
    cfg_denoiser.update_inner_model()

    return True
193
class TorchHijack:
    """Stand-in for the `torch` module inside k-diffusion, overriding randn_like.

    k-diffusion exposes a random_sampler argument for most samplers but not
    all of them, so every torch.randn_like call is routed through this proxy
    to draw noise from the processing object's RNG instead. This keeps images
    generated in batches the same as images generated individually.
    """

    def __init__(self, p):
        self.rng = p.rng

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if not hasattr(torch, item):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        # Everything else is forwarded to the real torch module.
        return getattr(torch, item)

    def randn_like(self, x):
        # x is ignored; the noise comes directly from the configured RNG.
        return self.rng.next()
217
class Sampler:
    """Base class for sampler wrappers.

    Holds state common to all samplers: the sampling function, eta handling,
    sigma-schedule tweak parameters (s_churn/s_tmin/s_tmax/s_noise) and the
    CFG denoiser wrapper. Subclasses implement sample() and sample_img2img().
    """

    def __init__(self, funcname):
        self.funcname = funcname
        # Starts as the function-name string; NOTE(review): subclasses
        # presumably replace self.func with the resolved callable — confirm.
        self.func = funcname
        self.extra_params = []  # names of p attributes forwarded to self.func when it accepts them
        self.sampler_noises = None
        self.stop_at = None  # step index past which callback_state aborts via InterruptedException
        self.eta = None
        self.config: SamplerData = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        # Stochasticity ("sigma tweak") defaults; used in initialize() to
        # detect whether a non-default value was requested.
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        # Which opts field supplies eta, and how it is labelled in infotext.
        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.p = None  # current processing object; assigned in initialize()
        self.model_wrap_cfg = None  # CFG denoiser wrapper; assigned elsewhere
        self.sampler_extra_args = None
        self.options = {}

    def callback_state(self, d):
        """Per-step callback: enforce stop_at and update progress state."""
        step = d['i']

        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Run func() with progress bookkeeping; on interruption (or a
        RecursionError from extreme scheduler settings) return the last
        latent produced instead of raising."""
        self.model_wrap_cfg.steps = steps
        self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Number of noise tensors the sampler will consume (one per step)."""
        return p.steps

    def initialize(self, p) -> dict:
        """Bind this sampler to processing object p and build the dict of
        extra keyword arguments (eta, s_churn, ...) accepted by self.func."""
        self.p = p
        self.model_wrap_cfg.p = p
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        # Route torch.randn_like inside k-diffusion through our RNG.
        k_diffusion.sampling.torch = TorchHijack(p)

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            # Forward only parameters that p provides AND self.func accepts.
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            if self.eta != self.eta_default:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            # An opts attribute, when present, takes precedence over p's value.
            s_churn = getattr(opts, 's_churn', p.s_churn)
            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax  # 0 = inf
            s_noise = getattr(opts, 's_noise', p.s_noise)

            # For each tweak the sampler accepts, apply it only when it
            # deviates from the default, and record it in generation params.
            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                p.s_churn = s_churn
                p.extra_generation_params['Sigma churn'] = s_churn
            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                p.s_tmin = s_tmin
                p.extra_generation_params['Sigma tmin'] = s_tmin
            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                p.s_tmax = s_tmax
                p.extra_generation_params['Sigma tmax'] = s_tmax
            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                p.s_noise = s_noise
                p.extra_generation_params['Sigma noise'] = s_noise

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # Seed with this iteration's per-image seeds so the noise does not
        # depend on batch size.
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """txt2img entry point; implemented by subclasses."""
        raise NotImplementedError()

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """img2img entry point; implemented by subclasses."""
        raise NotImplementedError()

    def add_infotext(self, p):
        """Record cond/uncond padding flags in the generation parameters."""
        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        if self.model_wrap_cfg.padded_cond_uncond_v0:
            p.extra_generation_params["Pad conds v0"] = True