stable-diffusion-webui

Форк
0
/
sd_samplers_extra.py 
74 строки · 3.1 Кб
1
import torch
2
import tqdm
3
import k_diffusion.sampling
4

5

6
@torch.no_grad()
7
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
8
    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
9
    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
10
    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
11
    """
12
    extra_args = {} if extra_args is None else extra_args
13
    s_in = x.new_ones([x.shape[0]])
14
    step_id = 0
15
    from k_diffusion.sampling import to_d, get_sigmas_karras
16

17
    def heun_step(x, old_sigma, new_sigma, second_order=True):
18
        nonlocal step_id
19
        denoised = model(x, old_sigma * s_in, **extra_args)
20
        d = to_d(x, old_sigma, denoised)
21
        if callback is not None:
22
            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
23
        dt = new_sigma - old_sigma
24
        if new_sigma == 0 or not second_order:
25
            # Euler method
26
            x = x + d * dt
27
        else:
28
            # Heun's method
29
            x_2 = x + d * dt
30
            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
31
            d_2 = to_d(x_2, new_sigma, denoised_2)
32
            d_prime = (d + d_2) / 2
33
            x = x + d_prime * dt
34
        step_id += 1
35
        return x
36

37
    steps = sigmas.shape[0] - 1
38
    if restart_list is None:
39
        if steps >= 20:
40
            restart_steps = 9
41
            restart_times = 1
42
            if steps >= 36:
43
                restart_steps = steps // 4
44
                restart_times = 2
45
            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
46
            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
47
        else:
48
            restart_list = {}
49

50
    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
51

52
    step_list = []
53
    for i in range(len(sigmas) - 1):
54
        step_list.append((sigmas[i], sigmas[i + 1]))
55
        if i + 1 in restart_list:
56
            restart_steps, restart_times, restart_max = restart_list[i + 1]
57
            min_idx = i + 1
58
            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
59
            if max_idx < min_idx:
60
                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
61
                while restart_times > 0:
62
                    restart_times -= 1
63
                    step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
64

65
    last_sigma = None
66
    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
67
        if last_sigma is None:
68
            last_sigma = old_sigma
69
        elif last_sigma < old_sigma:
70
            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
71
        x = heun_step(x, old_sigma, new_sigma)
72
        last_sigma = new_sigma
73

74
    return x
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.