stable-diffusion-webui

Форк
0
/
sd_samplers_lcm.py 
104 строки · 3.8 Кб
1
import torch
2

3
from k_diffusion import utils, sampling
4
from k_diffusion.external import DiscreteEpsDDPMDenoiser
5
from k_diffusion.sampling import default_noise_sampler, trange
6

7
from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
8

9

10
class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
11
    def __init__(self, model):
12
        timesteps = 1000
13
        original_timesteps = 50     # LCM Original Timesteps (default=50, for current version of LCM)
14
        self.skip_steps = timesteps // original_timesteps
15

16
        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
17
        for x in range(original_timesteps):
18
            alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
19

20
        super().__init__(model, alphas_cumprod_valid, quantize=None)
21

22

23
    def get_sigmas(self, n=None,):
24
        if n is None:
25
            return sampling.append_zero(self.sigmas.flip(0))
26

27
        start = self.sigma_to_t(self.sigma_max)
28
        end = self.sigma_to_t(self.sigma_min)
29

30
        t = torch.linspace(start, end, n, device=shared.sd_model.device)
31

32
        return sampling.append_zero(self.t_to_sigma(t))
33

34

35
    def sigma_to_t(self, sigma, quantize=None):
36
        log_sigma = sigma.log()
37
        dists = log_sigma - self.log_sigmas[:, None]
38
        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
39

40

41
    def t_to_sigma(self, timestep):
42
        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
43
        return super().t_to_sigma(t)
44

45

46
    def get_eps(self, *args, **kwargs):
47
        return self.inner_model.apply_model(*args, **kwargs)
48

49

50
    def get_scaled_out(self, sigma, output, input):
51
        sigma_data = 0.5
52
        scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
53

54
        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
55
        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
56

57
        return c_out * output + c_skip * input
58

59

60
    def forward(self, input, sigma, **kwargs):
61
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
62
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
63
        return self.get_scaled_out(sigma, input + eps * c_out, input)
64

65

66
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
67
    extra_args = {} if extra_args is None else extra_args
68
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
69
    s_in = x.new_ones([x.shape[0]])
70

71
    for i in trange(len(sigmas) - 1, disable=disable):
72
        denoised = model(x, sigmas[i] * s_in, **extra_args)
73

74
        if callback is not None:
75
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
76

77
        x = denoised
78
        if sigmas[i + 1] > 0:
79
            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
80
    return x
81

82

83
class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
84
    @property
85
    def inner_model(self):
86
        if self.model_wrap is None:
87
            denoiser = LCMCompVisDenoiser
88
            self.model_wrap = denoiser(shared.sd_model)
89

90
        return self.model_wrap
91

92

93
class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
94
    def __init__(self, funcname, sd_model, options=None):
95
        super().__init__(funcname, sd_model, options)
96
        self.model_wrap_cfg = CFGDenoiserLCM(self)
97
        self.model_wrap = self.model_wrap_cfg.inner_model
98

99

100
samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
101
samplers_data_lcm = [
102
    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
103
    for label, funcname, aliases, options in samplers_lcm
104
]
105

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.