stable-diffusion-webui

Форк
0
/
sd_samplers_cfg_denoiser.py 
298 строк · 13.5 Кб
1
import torch
2
from modules import prompt_parser, devices, sd_samplers_common
3

4
from modules.shared import opts, state
5
import modules.shared as shared
6
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
7
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
8
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
9

10

11
def catenate_conds(conds):
    """
    Concatenate a sequence of conditionings along the batch dimension (dim 0).

    Plain tensors are concatenated directly. Dict conditionings are concatenated
    key by key, using the keys (and key order) of the first element; a new plain
    dict is returned.
    """
    first = conds[0]
    if isinstance(first, dict):
        merged = {}
        for key in first.keys():
            merged[key] = torch.cat([item[key] for item in conds])
        return merged

    return torch.cat(conds)
16

17

18
def subscript_cond(cond, a, b):
    """
    Take rows a:b (along dim 0) of a conditioning.

    A plain tensor is sliced directly. For a dict conditioning every value is
    sliced and a new plain dict is returned; the input dict is not modified.
    """
    rows = slice(a, b)
    if isinstance(cond, dict):
        return {name: value[rows] for name, value in cond.items()}

    return cond[rows]
23

24

25
def pad_cond(tensor, repeats, empty):
    """
    Append copies of the `empty` embedding to a conditioning along dim 1.

    `empty` is tiled with repeat((batch, repeats, 1)), where batch is dim 0 of
    `tensor` (this assumes `empty` is 3-D with a batch of 1 — TODO confirm), and
    concatenated after the existing entries along dim 1.

    For a dict conditioning only the 'crossattn' entry is padded; the dict is
    updated in place and the same object is returned.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)
31

32

33
class CFGDenoiser(torch.nn.Module):
    """
    Classifier-free guidance denoiser. A wrapper for the Stable Diffusion model (specifically its UNet)
    that takes a noisy latent and produces a denoised prediction using two guidances (prompts)
    instead of one. Originally the second prompt is just an empty string, but here a non-empty
    negative prompt is used.

    Subclasses must provide `inner_model`; `get_pred_x0` may be overridden to turn the raw model
    output into the x0 prediction used for previews.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None

        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # number of completed forward() calls; selects the scheduled prompt and drives skip-uncond alternation
        self.step = 0
        # image CFG scale for "edit" models; the edit-model path is used only when this is set and != 1.0
        self.image_cfg_scale = None
        # set by forward() when cond/uncond had to be padded on the current step
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """Model called as inner_model(x, sigma, cond=...); must be provided by subclasses."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """
        Apply classifier-free guidance to the batched model output.

        The last uncond.shape[0] rows of x_out are the uncond predictions. For image i the result is
        uncond_i + sum over (cond_index, weight) in conds_list[i] of
        (x_out[cond_index] - uncond_i) * weight * cond_scale.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """
        Combine the three-way edit-model output.

        x_out is split into three equal chunks: (text cond + image cond), (text uncond + image cond),
        (text uncond + zero image cond). Text guidance uses cond_scale and image guidance uses
        self.image_cfg_scale.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return the x0 prediction for previews; the base class treats x_out as x0 already."""
        return x_out

    def update_inner_model(self):
        """
        Drop the cached model wrapper and store freshly computed conds in the sampler's extra args.

        Presumably called after the underlying model changes (e.g. a refiner switch) — TODO confirm callers.
        """
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """
        Make cond and uncond the same length along dim 1 by appending copies of the model's
        empty-prompt embedding (shared.sd_model.cond_stage_model_empty_prompt) to the shorter one.

        The number of copies is the length difference floor-divided by the empty-prompt length, so this
        assumes the difference is a multiple of that length (TODO confirm for every text encoder).
        Sets self.padded_cond_uncond when padding was applied. Dict conditionings are padded in place
        (see pad_cond).

        Returns:
            tuple: (cond, uncond), one of them possibly padded.
        """
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Make 'uncond' the same length as 'cond' along dim 1 (the token dimension).

        If 'uncond' is a dictionary, its 'crossattn' entry is resized and written back into the same
        dictionary (in place). If 'uncond' is a tensor, it is resized directly.

        If 'uncond' is shorter than 'cond', its last token vector is repeated until the lengths match.
        If 'uncond' is longer than 'cond', it is truncated to the length of 'cond'.
        Sets self.padded_cond_uncond_v0 when either adjustment happened.

        Args:
            cond (torch.Tensor or DictWithShape): The condition whose length 'uncond' is matched to; left unchanged.
            uncond (torch.Tensor or DictWithShape): The tensor to be resized, or a dictionary containing it under 'crossattn'.

        Returns:
            tuple: A tuple containing the unchanged 'cond' and the resized 'uncond'.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """

        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """
        Run one classifier-free-guidance denoising call.

        Args:
            x: current noisy latents, one entry per image in the batch (indexed as x[i]).
            sigma: noise level per image (indexed as sigma[i]).
            uncond: scheduled negative-prompt conditioning, resolved for this step by
                prompt_parser.reconstruct_cond_batch.
            cond: scheduled positive-prompt conditioning (may contain several weighted sub-prompts per
                image), resolved for this step by prompt_parser.reconstruct_multicond_batch.
            cond_scale: CFG scale applied to (cond - uncond).
            s_min_uncond: on odd steps, when > 0 and sigma[0] is below it, the uncond pass is skipped.
            image_cond: extra image conditioning, passed as c_concat (or c_adm for "crossattn-adm" models).

        Returns:
            The guided prediction after optional mask blending and the after-CFG callbacks.

        Raises:
            sd_samplers_common.InterruptedException: if the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if sd_samplers_common.apply_refiner(self):
            # the model was switched mid-sampling; use the conds stored in sampler_extra_args
            # (presumably refreshed through update_inner_model) instead of the passed-in ones
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        # conds_list[i] holds (row index into `tensor`, weight) pairs for image i
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        batch_size = len(conds_list)
        # number of cond rows (sub-prompts) per image
        repeats = [len(conds) for conds in conds_list]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        # Batch layout: one row per cond (image i repeated repeats[i] times), then the uncond rows.
        # The edit model adds a third group (text uncond + zero image cond) at the end.
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]
            # keep the image conditioning row-aligned with x_in; otherwise the batched call below would
            # pass c_concat/c_adm with batch_size more rows than x_in
            image_cond_in = image_cond_in[:-batch_size]

        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
        elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond(tensor, uncond)

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            if shared.opts.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # cond and uncond have different lengths along dim 1, so they cannot share one model call:
            # denoise the cond rows in chunks, then all uncond rows in a separate call.
            x_out = torch.zeros_like(x_in)
            chunk_size = batch_size * 2 if shared.opts.batch_cond_uncond else batch_size
            num_cond_rows = tensor.shape[0]
            for a in range(0, num_cond_rows, chunk_size):
                b = min(a + chunk_size, num_cond_rows)
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(tensor, a, b), image_cond_in[a:b]))

            if not skip_uncond:
                # the edit model has two trailing groups that both use the text uncond
                # (with image cond and with zero image cond), mirroring [tensor, uncond, uncond] above
                if is_edit_model:
                    num_uncond_rows = uncond.shape[0] * 2
                    uncond_in = catenate_conds([uncond, uncond])
                else:
                    num_uncond_rows = uncond.shape[0]
                    uncond_in = uncond
                x_out[-num_uncond_rows:] = self.inner_model(x_in[-num_uncond_rows:], sigma_in[-num_uncond_rows:], cond=make_condition_dict(uncond_in, image_cond_in[-num_uncond_rows:]))

        # row of x_out holding the first cond prediction of each image
        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised
299

300

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.