stable-diffusion-webui

Форк
0
/
sd_samplers_common.py 
344 строки · 12.6 Кб
1
import inspect
2
from collections import namedtuple
3
import numpy as np
4
import torch
5
from PIL import Image
6
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
7
from modules.shared import opts, state
8
import k_diffusion.sampling
9

10

11
SamplerDataTuple = namedtuple('SamplerData', ('name', 'constructor', 'aliases', 'options'))


class SamplerData(SamplerDataTuple):
    """Named tuple describing one sampler: its name, constructor, aliases and options mapping."""

    def total_steps(self, steps):
        """Return the step count to report for a run of `steps` steps.

        When the sampler's options carry a truthy ``second_order`` flag the count is doubled;
        otherwise `steps` is returned unchanged.
        """
        return steps * 2 if self.options.get("second_order", False) else steps
20

21

22
def setup_img2img_steps(p, steps=None):
    """Compute the schedule length and the encode step index for an img2img run.

    When ``opts.img2img_fix_steps`` is set or an explicit `steps` is given, the requested step
    count (``steps or p.steps``) is divided by the denoising strength (capped at 0.999). This
    yields the total schedule length, or 0 when the strength is not positive. In that mode
    ``t_enc`` is ``requested - 1``.

    Otherwise the schedule is ``p.steps`` long and ``t_enc`` is
    ``int(min(strength, 0.999) * p.steps)``.

    Returns:
        (steps, t_enc) tuple.
    """
    if not (opts.img2img_fix_steps or steps is not None):
        total = p.steps
        return total, int(min(p.denoising_strength, 0.999) * total)

    requested = steps or p.steps
    if p.denoising_strength > 0:
        total = int(requested / min(p.denoising_strength, 0.999))
    else:
        total = 0

    return total, requested - 1
32

33

34
# Maps the method names used by the preview / VAE encode / VAE decode options to the integer
# codes dispatched on by samples_to_images_tensor and images_tensor_to_samples.
approximation_indexes = {
    "Full": 0,  # full model decode/encode (also the fallback for unknown names)
    "Approx NN": 1,  # sd_vae_approx.model() when decoding
    "Approx cheap": 2,  # sd_vae_approx.cheap_approximation() when decoding
    "TAESD": 3,  # sd_vae_taesd decoder / encoder models
}
35

36

37
def samples_to_images_tensor(sample, approximation=None, model=None):
    """Transform 4-channel latent space images into 3-channel RGB image tensors with values in [-1, 1].

    Args:
        sample: batch of latents to decode.
        approximation: method code from ``approximation_indexes``. It is replaced with the code for
            ``opts.show_progress_type`` when it is None, or when state is interrupted and
            ``opts.live_preview_fast_interrupt`` is set.
        model: model used for the full decode; defaults to ``shared.sd_model``.
    """
    use_preview_method = approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt)
    if use_preview_method:
        approximation = approximation_indexes.get(opts.show_progress_type, 0)

        from modules import lowvram
        # with lowvram enabled, the full decode is downgraded to "Approx NN" unless explicitly allowed
        if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
            approximation = 1

    if approximation == 2:
        return sd_vae_approx.cheap_approximation(sample)

    if approximation == 1:
        return sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()

    if approximation == 3:
        taesd_out = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
        # rescale to the [-1, 1] range promised above (presumably TAESD decodes to [0, 1])
        return taesd_out * 2 - 1

    if model is None:
        model = shared.sd_model
    with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
        return model.decode_first_stage(sample.to(model.first_stage_model.dtype))
61

62

63
def single_sample_to_image(sample, approximation=None):
    """Decode a single latent (no batch dimension) into a PIL image.

    The decoded [-1, 1] tensor is shifted to [0, 1], clamped, moved channel-last, scaled to
    0..255 and truncated to uint8.
    """
    decoded = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0]
    unit_range = torch.clamp(decoded * 0.5 + 0.5, min=0.0, max=1.0)
    channel_last = np.moveaxis(unit_range.cpu().numpy(), 0, 2)
    pixels = (255. * channel_last).astype(np.uint8)
    return Image.fromarray(pixels)
71

72

73
def decode_first_stage(model, x):
    """Decode latents `x` with `model`, using the method named by ``opts.sd_vae_decode_method``.

    `x` is cast to ``devices.dtype_vae`` first; unknown method names fall back to the full decode.
    """
    latents = x.to(devices.dtype_vae)
    method = approximation_indexes.get(opts.sd_vae_decode_method, 0)
    return samples_to_images_tensor(latents, method, model)
77

78

79
def sample_to_image(samples, index=0, approximation=None):
    """Decode the latent at position `index` of a batch into a PIL image."""
    chosen = samples[index]
    return single_sample_to_image(chosen, approximation)
81

82

83
def samples_to_image_grid(samples, approximation=None):
    """Decode every latent in `samples` and lay the results out with ``images.image_grid``."""
    decoded_images = [single_sample_to_image(latent, approximation) for latent in samples]
    return images.image_grid(decoded_images)
85

86

87
def images_tensor_to_samples(image, approximation=None, model=None):
    """Encode image tensors with values in [0, 1] into latents.

    Args:
        image: image batch in [0, 1].
        approximation: method code; defaults to the code for ``opts.sd_vae_encode_method``.
            Only 3 (TAESD) selects a different encoder; every other code uses the full model.
        model: model used for the full encode; defaults to ``shared.sd_model``. Note that its
            ``first_stage_model`` is cast to ``devices.dtype_vae`` in place.
    """
    if approximation is None:
        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

    if approximation == 3:
        image = image.to(devices.device, devices.dtype)
        return sd_vae_taesd.encoder_model()(image)

    if model is None:
        model = shared.sd_model
    model.first_stage_model.to(devices.dtype_vae)

    # the full encoder takes inputs in [-1, 1]
    signed = image.to(shared.device, dtype=devices.dtype_vae) * 2 - 1

    def encode(batch):
        return model.get_first_stage_encoding(model.encode_first_stage(batch))

    if len(signed) <= 1:
        return encode(signed)

    # multi-image batches are encoded one image at a time and stacked back together
    return torch.stack([encode(torch.unsqueeze(single, 0))[0] for single in signed])
113

114

115
def store_latent(decoded):
    """Save the latent on the shared state and refresh the live preview when one is due.

    A preview image is produced only when live previews are enabled, the step interval is
    positive and divides the current sampling step, and parallel processing is not allowed.
    """
    state.current_latent = decoded

    if not opts.live_previews_enable or opts.show_progress_every_n_steps <= 0:
        return
    if shared.state.sampling_step % opts.show_progress_every_n_steps != 0:
        return
    if shared.parallel_processing_allowed:
        return

    shared.state.assign_current_image(sample_to_image(decoded))
121

122

123
def is_sampler_using_eta_noise_seed_delta(p):
    """Return whether the sampler from config will use eta noise seed delta for image creation.

    eta is resolved in order:
    1. ``p.eta``;
    2. ``p.sampler.eta``;
    3. the config default (0 when its options set ``default_eta_is_0``, else 1.0).

    An eta of 0 means no. Otherwise the answer is the config's ``uses_ensd`` option.
    If ``p.sampler_name`` does not resolve to a sampler config, returns False instead of
    raising AttributeError on the missing config.
    """
    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)

    eta = p.eta

    if eta is None and p.sampler is not None:
        eta = p.sampler.eta

    if eta is None and sampler_config is not None:
        eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0

    if eta == 0:
        return False

    # unknown sampler name: there is no config to read the ENSD flag from
    if sampler_config is None:
        return False

    return sampler_config.options.get("uses_ensd", False)
140

141

142
class InterruptedException(BaseException):
    """Signal used to stop sampling early.

    ``Sampler.callback_state`` raises it once the step passes ``stop_at``, and
    ``Sampler.launch_sampling`` catches it and returns the last latent. It presumably derives
    from BaseException so that generic ``except Exception`` handlers inside sampler code do not
    swallow it.
    """
144

145

146
def replace_torchsde_browinan():
    """Replace torchsde's Brownian-interval ``_randn`` with one backed by ``devices.randn_local``.

    The function name keeps its historical spelling since other modules may refer to it.
    """
    import torchsde._brownian.brownian_interval as brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        noise = devices.randn_local(seed, size)
        return noise.to(device=device, dtype=dtype)

    brownian_interval._randn = torchsde_randn


replace_torchsde_browinan()
156

157

158
def apply_refiner(cfg_denoiser):
    """Load the refiner checkpoint once sampling has progressed far enough.

    Returns False without side effects when any of these hold:
    - the completed fraction of steps is still below ``p.refiner_switch_at``;
    - no refiner checkpoint is set, or it is already the loaded model;
    - hires fix is enabled and ``opts.hires_fix_refiner_pass`` excludes the current pass
      ("first pass" skips the hires pass, "second pass" skips the first pass).

    Otherwise it does the following and returns True:
    - records the refiner in ``p.extra_generation_params``;
    - reloads weights without writing the config;
    - runs garbage collection;
    - rebuilds conds and the inner model.
    """
    progress = cfg_denoiser.step / cfg_denoiser.total_steps
    p = cfg_denoiser.p
    switch_at = p.refiner_switch_at
    refiner_info = p.refiner_checkpoint_info

    if switch_at is not None and progress < switch_at:
        return False

    if refiner_info is None or shared.sd_model.sd_checkpoint_info == refiner_info:
        return False

    if getattr(p, "enable_hr", False):
        on_hires_pass = p.is_hr_pass
        pass_mode = opts.hires_fix_refiner_pass

        wrong_pass = (pass_mode == "first pass" and on_hires_pass) or (pass_mode == "second pass" and not on_hires_pass)
        if wrong_pass:
            return False

        # "second pass" is not written to infotext; any other mode is
        if pass_mode != "second pass":
            p.extra_generation_params['Hires refiner'] = pass_mode

    p.extra_generation_params['Refiner'] = refiner_info.short_title
    p.extra_generation_params['Refiner switch at'] = switch_at

    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=refiner_info)

    devices.torch_gc()
    p.setup_conds()
    cfg_denoiser.update_inner_model()

    return True
192

193

194
class TorchHijack:
    """Stand-in for the ``torch`` module inside k-diffusion that overrides ``randn_like``.

    k-diffusion has a random_sampler argument for most samplers, but not for all, so replacing
    ``torch`` is needed to catch every use of ``torch.randn_like``. Noise comes from ``p.rng`` so
    that images generated in batches match images generated individually. Every other attribute
    is forwarded to the real ``torch`` module.
    """

    def __init__(self, p):
        self.rng = p.rng

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        missing = object()
        forwarded = getattr(torch, item, missing)
        if forwarded is not missing:
            return forwarded

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        # `x` is ignored: the next tensor from p.rng is returned as-is
        # NOTE(review): assumes the rng already yields tensors shaped like `x` — confirm
        return self.rng.next()
216

217

218
class Sampler:
    """Common base for samplers.

    Subclasses implement ``sample`` / ``sample_img2img`` (both raise NotImplementedError here).
    ``model_wrap_cfg`` and ``config`` start as None and must be set before sampling. The
    constructor sets ``config``; the code that sets ``model_wrap_cfg`` is not visible here —
    presumably subclasses.
    """

    def __init__(self, funcname):
        # presumably the sampling callable itself: initialize() inspects its signature
        self.funcname = funcname
        self.func = funcname
        self.extra_params = []
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config: SamplerData = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None
        # sigma parameter defaults; initialize() records only values that differ from these
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.p = None
        self.model_wrap_cfg = None
        self.sampler_extra_args = None
        self.options = {}

    def callback_state(self, d):
        """Per-step callback: abort past ``stop_at``, otherwise record the step and tick the progress bar.

        Raises:
            InterruptedException: when ``stop_at`` is set and the step index exceeds it.
        """
        step = d['i']

        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Reset step bookkeeping and run ``func()``.

        Returns ``self.last_latent`` instead of propagating RecursionError or InterruptedException.
        """
        self.model_wrap_cfg.steps = steps
        self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return how many noise tensors are needed; the base implementation returns ``p.steps``."""
        return p.steps

    def initialize(self, p) -> dict:
        """Bind processing object `p` and build extra keyword arguments for the sampling function.

        Also installs TorchHijack as k-diffusion's ``torch``. It may write eta and sigma
        overrides back to `p` and to ``p.extra_generation_params``.

        Returns:
            dict of keyword arguments whose names appear in ``self.func``'s signature.
        """
        self.p = p
        self.model_wrap_cfg.p = p
        self.model_wrap_cfg.mask = getattr(p, 'mask', None)
        self.model_wrap_cfg.nmask = getattr(p, 'nmask', None)
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        # route k-diffusion's torch.randn_like calls through p.rng
        k_diffusion.sampling.torch = TorchHijack(p)

        # inspect the sampling function's signature once rather than on every lookup
        func_params = inspect.signature(self.func).parameters

        extra_params_kwargs = {
            name: getattr(p, name)
            for name in self.extra_params
            if hasattr(p, name) and name in func_params
        }

        if 'eta' in func_params:
            if self.eta != self.eta_default:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            s_churn = getattr(opts, 's_churn', p.s_churn)
            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax  # 0 = inf
            s_noise = getattr(opts, 's_noise', p.s_noise)

            # (attribute name, resolved value, infotext label); only params the sampling function
            # already accepts, and whose value differs from this sampler's default, are overridden
            overrides = (
                ('s_churn', s_churn, 'Sigma churn'),
                ('s_tmin', s_tmin, 'Sigma tmin'),
                ('s_tmax', s_tmax, 'Sigma tmax'),
                ('s_noise', s_noise, 'Sigma noise'),
            )
            for name, value, infotext_label in overrides:
                if name in extra_params_kwargs and value != getattr(self, name):
                    extra_params_kwargs[name] = value
                    setattr(p, name, value)
                    p.extra_generation_params[infotext_label] = value

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes.

        Returns None when ``opts.no_dpmpp_sde_batch_determinism`` is set. Otherwise returns a
        BrownianTreeNoiseSampler seeded with the seeds of the current batch (``p.iteration``).
        """
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """txt2img sampling; must be implemented by subclasses."""
        raise NotImplementedError()

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """img2img sampling; must be implemented by subclasses."""
        raise NotImplementedError()

    def add_infotext(self, p):
        """Record in ``p.extra_generation_params`` whether the cond/uncond padding flags were used."""
        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        if self.model_wrap_cfg.padded_cond_uncond_v0:
            p.extra_generation_params["Pad conds v0"] = True
345

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.