stable-diffusion-webui

Форк
0
170 строк · 6.4 Кб
1
import torch
2

3
from modules import devices, rng_philox, shared
4

5

6
def randn(seed, shape, generator=None):
    """Seed the global random number generator, then draw one normally distributed tensor.

    The seed is applied through manual_seed(), so later calls to randn_like() or
    randn_without_seed() continue from the state this call leaves behind.

    If generator is given, the numbers are drawn from it instead of the global
    generator; the global generator is still reseeded.
    """
    manual_seed(seed)

    # Once the global state is seeded, the draw is exactly what
    # randn_without_seed() performs: same source selection, same device handling.
    return randn_without_seed(shape, generator=generator)
20

21

22
def randn_local(seed, shape):
    """Draw a normally distributed tensor for seed without touching the global generator.

    A fresh generator is built from seed on every call, so this can only ever
    return the first tensor of that seed's sequence.
    """
    source = shared.opts.randn_source

    if source == "NV":
        philox = rng_philox.Generator(seed)
        return torch.asarray(philox.randn(shape), device=devices.device)

    # Sample on the CPU when the option asks for it or when the target device is MPS;
    # the result is moved to the target device afterwards either way.
    if source == "CPU" or devices.device.type == 'mps':
        sampling_device = devices.cpu
    else:
        sampling_device = devices.device

    seeded = torch.Generator(sampling_device).manual_seed(int(seed))
    noise = torch.randn(shape, device=sampling_device, generator=seeded)
    return noise.to(devices.device)
34

35

36
def randn_like(x):
    """Return normally distributed noise shaped like x, drawn from the previously initialized generator.

    Initialize the generator first with either randn() or manual_seed().
    """
    source = shared.opts.randn_source

    if source == "NV":
        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)

    # Fast path: sample directly on x's own device.
    if source != "CPU" and x.device.type != 'mps':
        return torch.randn_like(x)

    # CPU source requested, or x lives on MPS: sample on the CPU and move the result back.
    return torch.randn_like(x, device=devices.cpu).to(x.device)
48

49

50
def randn_without_seed(shape, generator=None):
    """Draw a normally distributed tensor from the previously initialized generator.

    Initialize the generator first with either randn() or manual_seed().
    If generator is given, it is used instead of the global one.
    """
    source = shared.opts.randn_source

    if source == "NV":
        philox = generator or nv_rng
        return torch.asarray(philox.randn(shape), device=devices.device)

    sample_on_cpu = source == "CPU" or devices.device.type == 'mps'
    if sample_on_cpu:
        sample = torch.randn(shape, device=devices.cpu, generator=generator)
        return sample.to(devices.device)

    return torch.randn(shape, device=devices.device, generator=generator)
62

63

64
def manual_seed(seed):
    """Initialize the global random number generator with seed.

    With the "NV" source the module-level nv_rng is replaced by a new Philox
    generator; otherwise torch's global seed is set.
    """
    global nv_rng

    if shared.opts.randn_source != "NV":
        torch.manual_seed(seed)
        return

    nv_rng = rng_philox.Generator(seed)
73

74

75
def create_generator(seed):
    """Build a standalone generator seeded with seed.

    Returns a rng_philox.Generator for the "NV" source, otherwise a
    torch.Generator placed on the CPU (when the "CPU" source is selected or the
    target device is MPS) or on the target device.
    """
    if shared.opts.randn_source == "NV":
        return rng_philox.Generator(seed)

    on_cpu = shared.opts.randn_source == "CPU" or devices.device.type == 'mps'
    target = devices.cpu if on_cpu else devices.device
    return torch.Generator(target).manual_seed(int(seed))
82

83

84
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate between low and high.

    val is the interpolation weight: 0 yields low, 1 yields high.
    The angle between the inputs is measured along dim 1 (both are normalized
    along that dimension before taking the dot product).

    When the inputs are almost parallel the spherical formula would divide by a
    sine close to zero, so plain linear interpolation is used instead.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Linear fallback must be oriented like the spherical branch below:
        # weight (1 - val) on low and val on high, so val=0 returns low.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
97

98

99
class ImageRNG:
    """Produces batches of noise tensors, one tensor per seed.

    The first batch applies the subseed blending and seed-resize options; every
    later batch is drawn directly from the per-seed generators.
    """

    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
        """Store the noise configuration and create one generator per seed."""
        self.shape = tuple(int(dim) for dim in shape)
        self.seeds = seeds
        self.subseeds = subseeds
        self.subseed_strength = subseed_strength
        self.seed_resize_from_h = seed_resize_from_h
        self.seed_resize_from_w = seed_resize_from_w

        self.generators = [create_generator(seed) for seed in seeds]

        self.is_first = True

    def first(self):
        """Build the initial noise batch, honouring subseeds and seed resizing."""
        # With both resize dimensions positive, the seed noise is generated at the
        # resize size divided by 8 (presumably the latent downscale factor - confirm)
        # and later pasted into a tensor of the real shape.
        if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0:
            noise_shape = self.shape
        else:
            noise_shape = (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))

        resized = noise_shape != self.shape
        blend_subseed = self.subseeds is not None and self.subseed_strength != 0

        batch = []

        for index, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
            subnoise = None
            if blend_subseed:
                # Seeds without a matching subseed fall back to subseed 0.
                subseed = self.subseeds[index] if index < len(self.subseeds) else 0
                subnoise = randn(subseed, noise_shape)

            # The order of the randn() calls matters: each one reseeds the global generator.
            if resized:
                noise = randn(seed, noise_shape)
            else:
                noise = randn(seed, self.shape, generator=generator)

            if subnoise is not None:
                noise = slerp(self.subseed_strength, noise, subnoise)

            if resized:
                canvas = randn(seed, self.shape, generator=generator)

                # Centre the resized noise inside the canvas; a negative offset
                # means the noise is larger than the canvas and gets cropped.
                offset_x = (self.shape[2] - noise_shape[2]) // 2
                offset_y = (self.shape[1] - noise_shape[1]) // 2
                width = noise_shape[2] + 2 * min(offset_x, 0)
                height = noise_shape[1] + 2 * min(offset_y, 0)
                dst_x, dst_y = max(offset_x, 0), max(offset_y, 0)
                src_x, src_y = max(-offset_x, 0), max(-offset_y, 0)

                canvas[:, dst_y:dst_y + height, dst_x:dst_x + width] = noise[:, src_y:src_y + height, src_x:src_x + width]
                noise = canvas

            batch.append(noise)

        # A non-zero delta replaces the generators used by subsequent next() calls.
        seed_delta = shared.opts.eta_noise_seed_delta or 0
        if seed_delta:
            self.generators = [create_generator(seed + seed_delta) for seed in self.seeds]

        return torch.stack(batch).to(shared.device)

    def next(self):
        """Return the next noise batch; the very first call delegates to first()."""
        if self.is_first:
            self.is_first = False
            return self.first()

        batch = [randn_without_seed(self.shape, generator=generator) for generator in self.generators]

        return torch.stack(batch).to(shared.device)
164

165

166
# Attach this module's RNG helpers to the devices module as attributes
# (presumably so existing callers can keep using devices.randn etc. - confirm).
devices.manual_seed = manual_seed
devices.randn = randn
devices.randn_like = randn_like
devices.randn_local = randn_local
devices.randn_without_seed = randn_without_seed
171

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.