stable-diffusion-webui

Форк
0
/
sd_samplers_timesteps_impl.py 
137 строк · 5.6 Кб
1
import torch
2
import tqdm
3
import k_diffusion.sampling
4
import numpy as np
5

6
from modules import shared
7
from modules.models.diffusion.uni_pc import uni_pc
8

9

10
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """DDIM sampler operating directly on the model's discrete timesteps.

    Walks `timesteps` from its last entry down to its first, repeatedly
    denoising `x`. With eta == 0 the update is fully deterministic; eta > 0
    mixes fresh gaussian noise in, scaled by the DDIM sigma schedule.
    """
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    # Shift the timestep indices by one to obtain alpha_bar for the previous
    # step; keep float64 precision except on backends (mps/xpu) without it.
    shifted_timesteps = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    prev_dtype = torch.float32 if x.device.type in ('mps', 'xpu') else torch.float64
    alphas_prev = alphas_cumprod[shifted_timesteps].to(prev_dtype)
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

    extra_args = {} if extra_args is None else extra_args
    t_batch = x.new_ones((x.shape[0]))
    broadcast = x.new_ones((x.shape[0], 1, 1, 1))

    for step in tqdm.trange(len(timesteps) - 1, disable=disable):
        index = len(timesteps) - 1 - step

        e_t = model(x, timesteps[index].item() * t_batch, **extra_args)

        # Per-timestep scalars, broadcast over the batch dimension.
        a_t = alphas[index].item() * broadcast
        a_prev = alphas_prev[index].item() * broadcast
        sigma_t = sigmas[index].item() * broadcast
        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * broadcast

        # DDIM update: predicted x_0, direction term, and (optional) noise.
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
        x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        if callback is not None:
            # No meaningful sigma on the timestep axis; report zeros.
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
40

41

42
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
    """PLMS (pseudo linear multistep) sampler over discrete timesteps.

    Combines past epsilon predictions with Adams-Bashforth coefficients,
    reaching up to 4th order once enough history has accumulated. The very
    first step, which has no history, uses a pseudo improved-Euler estimate
    at the cost of one extra model evaluation.
    """
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    # alpha_bar of the previous step; float64 except on mps/xpu backends.
    prev_dtype = torch.float32 if x.device.type in ('mps', 'xpu') else torch.float64
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(prev_dtype)
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    extra_args = {} if extra_args is None else extra_args
    t_batch = x.new_ones([x.shape[0]])
    broadcast = x.new_ones((x.shape[0], 1, 1, 1))
    eps_history = []

    def step_from_eps(latent, eps, index):
        # Deterministic (eta = 0) DDIM-style update from an epsilon estimate.
        a_t = alphas[index].item() * broadcast
        a_prev = alphas_prev[index].item() * broadcast
        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * broadcast

        # Current prediction for x_0 and the direction pointing back to x_t.
        pred_x0 = (latent - sqrt_one_minus_at * eps) / a_t.sqrt()
        dir_xt = (1. - a_prev).sqrt() * eps
        return a_prev.sqrt() * pred_x0 + dir_xt, pred_x0

    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
        index = len(timesteps) - 1 - i
        ts = timesteps[index].item() * t_batch
        t_next = timesteps[max(index - 1, 0)].item() * t_batch

        e_t = model(x, ts, **extra_args)

        order = len(eps_history)
        if order == 0:
            # Pseudo improved Euler (2nd order): average the epsilon at the
            # current and the tentatively stepped-to timestep.
            x_prev, _ = step_from_eps(x, e_t, index)
            e_t_next = model(x_prev, t_next, **extra_args)
            e_t_prime = (e_t + e_t_next) / 2
        elif order == 1:
            # 2nd order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (3 * e_t - eps_history[-1]) / 2
        elif order == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (23 * e_t - 16 * eps_history[-1] + 5 * eps_history[-2]) / 12
        else:
            # 4th order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (55 * e_t - 59 * eps_history[-1] + 37 * eps_history[-2] - 9 * eps_history[-3]) / 24

        x, pred_x0 = step_from_eps(x, e_t_prime, index)

        # Keep at most the three most recent epsilons for the 4th order step.
        eps_history.append(e_t)
        if len(eps_history) >= 4:
            eps_history.pop(0)

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
102

103

104
class UniPCCFG(uni_pc.UniPC):
    """UniPC sampler wired to a CFG-wrapped denoiser and webui callbacks."""

    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
        super().__init__(None, *args, **kwargs)

        self.cfg_model = cfg_model
        self.extra_args = extra_args
        self.callback = callback
        self.index = 0

        def after_update(x, model_x):
            # Mirror the k-diffusion callback payload; there is no meaningful
            # sigma on this axis, so report zeros.
            callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
            self.index += 1

        self.after_update = after_update

    def get_model_input_time(self, t_continuous):
        # Map UniPC's continuous time in (0, 1] onto the model's discrete
        # 0..999 timestep scale.
        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.

    def model(self, x, t):
        return self.cfg_model(x, self.get_model_input_time(t), **self.extra_args)
127

128

129
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
    """Sample with the UniPC sampler using the options from shared.opts."""
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod

    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

    # img2img starts mid-schedule, so map the first timestep onto UniPC's
    # continuous time axis. this is likely off by a bit - if someone wants to fix it please by all means
    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None

    sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
    return sampler.sample(
        x,
        steps=len(timesteps),
        t_start=t_start,
        skip_type=shared.opts.uni_pc_skip_type,
        method="multistep",
        order=shared.opts.uni_pc_order,
        lower_order_final=shared.opts.uni_pc_lower_order_final,
    )
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.