stable-diffusion-webui

Форк
0
/
prompt_parser.py 
464 строки · 16.3 Кб
1
from __future__ import annotations
2

3
import re
4
from collections import namedtuple
5
import lark
6

7
# A prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
# will be represented with a prompt schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
# Each entry is [step at which that text stops being used, text].

# Lark grammar for the prompt-editing syntax:
#   scheduled:  [before:after:when] or [after:when] -- switch text at step `when`
#   alternate:  [a|b|c]                             -- cycle through the options every step
#   emphasized: (x), (x:w), [x]                     -- kept verbatim (brackets included) in the output text
# `start` also accepts stray brackets/colons so that unbalanced prompts still parse.
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
alternate: "[" prompt ("|" [prompt])+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
27

28
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
    """
    Expand each prompt's prompt-editing syntax into a step schedule.

    Returns one schedule per entry of ``prompts``, in the same order. A schedule
    is a list of ``[end_step, text]`` pairs sorted by ``end_step``. ``text`` is
    the prompt as it reads up to and including ``end_step``. The last pair always
    ends at the total step count. A prompt that fails to parse is returned
    unchanged as ``[[steps, prompt]]``.

    ``base_steps`` is the step count of the first sampling pass. When
    ``hires_steps`` is given and ``use_old_scheduling`` is false, the schedule
    is computed for the hires pass instead:
    - Integer switch points are shifted down by ``base_steps``.
    - Numbers written with a decimal point are shifted down by 1.0 and then
      multiplied by ``hires_steps``, so ``1.5`` means halfway through the hires pass.

    Without hires, integers are absolute steps, and numbers with a decimal point
    are multiplied by ``base_steps``.

    ``use_old_scheduling`` restores the legacy rule, always on ``base_steps``:
    values below 1 are a fraction of the steps, anything else is an absolute step.

    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b  d'], [10, 'a  c  d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a  c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a  c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    >>> g("[fe|]male")
    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
    >>> g("[fe|||]male")
    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
    >>> g("a [b:.5] c")
    [[10, 'a b c']]
    >>> g("a [b:1.5] c")
    [[5, 'a  c'], [10, 'a b c']]
    """

    if hires_steps is None or use_old_scheduling:
        int_offset = 0
        flt_offset = 0
        steps = base_steps
    else:
        int_offset = base_steps
        flt_offset = 1.0
        steps = hires_steps

    def collect_steps(steps, tree):
        """Return the sorted, de-duplicated steps at which the prompt text changes.

        Side effect: each ``scheduled`` node's NUMBER child is replaced in place
        by the resolved integer step, which ``at_step`` later compares against.
        """
        res = [steps]

        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                # children are [before, after, ws, NUMBER, ws], so [-2] is the number token
                s = tree.children[-2]
                v = float(s)
                if use_old_scheduling:
                    v = v * steps if v < 1 else v
                else:
                    if "." in s:
                        v = (v - flt_offset) * steps
                    else:
                        v = (v - int_offset)
                tree.children[-2] = min(steps, int(v))
                # switch points before step 1 are still stored above but add no schedule entry
                if tree.children[-2] >= 1:
                    res.append(tree.children[-2])

            def alternate(self, tree):
                # [a|b] changes on every step, so every step needs its own entry
                res.extend(range(1, steps + 1))

        CollectSteps().visit(tree)
        return sorted(set(res))

    def at_step(step, tree):
        """Render ``tree`` to the plain prompt text that applies at ``step``."""
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                # `when` is the int written by collect_steps; `before` is None for [after:when]
                before, after, _, when, _ = args
                yield before or () if step <= when else after

            def alternate(self, args):
                # empty options such as the second one in [fe|] parse as None; render them as ""
                args = ["" if not arg else arg for arg in args]
                yield args[(step - 1) % len(args)]

            def start(self, args):
                # the other callbacks return nested generators of strings; join them up
                def flatten(x):
                    if isinstance(x, str):
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                for child in children:
                    yield child

        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError:
            # not valid prompt-editing syntax: use the prompt verbatim for every step
            return [[steps, prompt]]
        # collect_steps runs first and rewrites the step numbers in the tree that at_step reads
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    # parse each distinct prompt only once
    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]
134

135

136
# Immutable (end_at_step, cond) pair: `cond` is used up to and including step `end_at_step`.
ScheduledPromptConditioning = namedtuple(
    "ScheduledPromptConditioning",
    "end_at_step cond",
)
137

138

139
class SdConditioning(list):
    """
    A list of prompts for Stable Diffusion's conditioner model.

    Besides the prompts it carries optional metadata that some conditioners
    need (SDXL needs width and height):

    - ``is_negative_prompt``: whether these are negative prompts.
    - ``width`` / ``height``: size of the image being generated, or None.

    Any of these not given explicitly (or given as a falsy value) is taken from
    ``copy_from`` when it has that attribute. ``copy_from`` defaults to
    ``prompts`` itself, so wrapping an existing SdConditioning keeps its metadata.
    """

    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__()
        self.extend(prompts)

        if copy_from is None:
            copy_from = prompts

        # `or`: falsy explicit values (False, 0, None) fall back to copy_from's attribute
        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
        self.width = width or getattr(copy_from, 'width', None)
        self.height = height or getattr(copy_from, 'height', None)
154

155

156

157
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
    """Convert a list of prompts into a list of prompt schedules.

    Each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond)
    and the sampling step at which this condition is to be replaced by the next one.

    ``steps``, ``hires_steps`` and ``use_old_scheduling`` are passed through to
    get_learned_conditioning_prompt_schedules. ``model.get_learned_conditioning``
    is called once per distinct prompt, with all of that prompt's schedule texts
    as one SdConditioning batch. Identical prompts share the same schedule list.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0523,  ..., -0.4901, -0.3066,  0.0674], ..., [ 0.3317, -0.5102, -0.4066,  ...,  0.4119, -0.7647, -1.0160]], device='cuda:0'))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.0192,  0.3867, -0.4644,  ...,  0.1135, -0.3696, -0.4625]], device='cuda:0')),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.7352, -0.4356, -0.7888,  ...,  0.6994, -0.4312, -1.2593]], device='cuda:0'))
        ]
    ]
    """
    res = []

    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
    # identical prompts are encoded only once
    cache = {}

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):
        cached = cache.get(prompt)
        if cached is not None:
            res.append(cached)
            continue

        # encode all texts of this prompt's schedule in one call; copy_from carries over
        # the batch metadata (negative flag, width/height) to the conditioner
        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
        conds = model.get_learned_conditioning(texts)

        cond_schedule = []
        for i, (end_at_step, _) in enumerate(prompt_schedule):
            # the model returns either a batch indexable by row or a dict of such batches
            if isinstance(conds, dict):
                cond = {k: v[i] for k, v in conds.items()}
            else:
                cond = conds[i]

            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))

        cache[prompt] = cond_schedule
        res.append(cond_schedule)

    return res
203

204

205
# Splits a prompt into sub-prompts on the standalone word AND (composable prompts).
re_AND = re.compile(r"\bAND\b")
# Matches "<text>[:<weight>]"; captures the sub-prompt text and the optional trailing numeric weight.
re_weight = re.compile(
    r"^((?:\s|.)*?)"                           # group 1: the text (lazy, may span newlines)
    r"(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?"  # group 2: optional ":weight"
    r"\s*$"                                    # trailing whitespace is not part of the text
)
207

208

209
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
    """Split each prompt on AND and de-duplicate the resulting sub-prompts.

    Returns ``(res_indexes, prompt_flat_list, prompt_indexes)``:

    - ``res_indexes``: one list per input prompt of ``(index, weight)`` pairs.
      ``index`` points into ``prompt_flat_list``. ``weight`` is the sub-prompt's
      trailing ``:number``, or 1.0 when there is none.
    - ``prompt_flat_list``: an SdConditioning holding each distinct sub-prompt
      text once, carrying the metadata of ``prompts``.
    - ``prompt_indexes``: mapping from sub-prompt text to its index.
    """
    res_indexes = []

    prompt_indexes = {}
    # empty list that still inherits the input's metadata (negative flag, width/height)
    prompt_flat_list = SdConditioning([], copy_from=prompts)

    for prompt in prompts:
        indexes = []
        for subprompt in re_AND.split(prompt):
            match = re_weight.search(subprompt)
            text, weight = match.groups() if match is not None else (subprompt, 1.0)
            weight = float(weight) if weight is not None else 1.0

            index = prompt_indexes.get(text)
            if index is None:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index

            indexes.append((index, weight))

        res_indexes.append(indexes)

    return res_indexes, prompt_flat_list, prompt_indexes
238

239

240
class ComposableScheduledPromptConditioning:
    """One AND-separated sub-prompt: its conditioning schedule together with its weight."""

    def __init__(self, schedules, weight=1.0):
        # the sub-prompt's schedule of ScheduledPromptConditioning entries
        self.schedules: list[ScheduledPromptConditioning] = schedules
        # the sub-prompt's trailing ":weight" (1.0 when absent)
        self.weight: float = weight
244

245

246
class MulticondLearnedConditioning:
    """Composable conditioning for a batch: one list of weighted sub-prompts per prompt."""

    def __init__(self, shape, batch):
        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
        # batch[i] holds the ComposableScheduledPromptConditioning entries of prompt i
        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
250

251

252
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
    """Same as get_learned_conditioning, but for composable prompts.

    Each prompt is split on the AND separator into weighted sub-prompts, and
    every distinct sub-prompt in the whole batch is encoded once. The result's
    ``batch`` holds one list of ComposableScheduledPromptConditioning per prompt,
    pairing each sub-prompt's schedule with its weight. ``shape`` is
    ``(len(prompts),)``.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """
    res_indexes, prompt_flat_list, _ = get_multicond_prompt_list(prompts)

    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)

    res = [
        [ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]
        for indexes in res_indexes
    ]

    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
268

269

270
class DictWithShape(dict):
    """A dict of conditioning values that also exposes a ``shape`` attribute.

    ``shape`` reports the shape of the ``"crossattn"`` entry, which must be
    present when ``shape`` is read.
    """

    def __init__(self, x, shape=None):
        # `shape` is accepted for backward compatibility but not stored: the
        # property always reflects the current "crossattn" value.
        super().__init__()
        self.update(x)

    @property
    def shape(self):
        return self["crossattn"].shape
278

279

280
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
    """Select, for every prompt, the cond active at ``current_step`` and batch them.

    ``c`` holds one schedule per prompt. The active entry is the first one whose
    ``end_at_step`` is >= ``current_step``. If ``current_step`` is past every
    entry, the schedule's last entry is used.

    Returns a tensor of shape ``(len(c),) + cond.shape``. When conds are dicts,
    it returns a DictWithShape with one such batched tensor per key. Device,
    dtype and per-row shape come from the first cond of the first schedule;
    every other cond must fit that shape.
    """
    first = c[0][0].cond
    is_dict = isinstance(first, dict)

    if is_dict:
        res = {k: torch.zeros((len(c),) + t.shape, device=t.device, dtype=t.dtype) for k, t in first.items()}
        res = DictWithShape(res, (len(c),) + first['crossattn'].shape)
    else:
        res = torch.zeros((len(c),) + first.shape, device=first.device, dtype=first.dtype)

    for i, cond_schedule in enumerate(c):
        # default to the final entry so steps beyond the schedule keep the last prompt
        # instead of jumping back to the first one
        target_index = len(cond_schedule) - 1
        for current, entry in enumerate(cond_schedule):
            if current_step <= entry.end_at_step:
                target_index = current
                break

        if is_dict:
            for k, t in cond_schedule[target_index].cond.items():
                res[k][i] = t
        else:
            res[i] = cond_schedule[target_index].cond

    return res
305

306

307
def stack_conds(tensors):
    """Stack 2-D conds into one batch, padding shorter ones to a common row count.

    If prompts have wildly different lengths above the token limit, their conds
    have different numbers of rows and ``torch.stack`` cannot combine them. Each
    shorter tensor is padded by repeating its last row until it matches the
    longest one. The input list is left unmodified.
    """
    token_count = max(t.shape[0] for t in tensors)

    padded = []
    for t in tensors:
        missing = token_count - t.shape[0]
        if missing:
            # repeat([missing, 1]) assumes a 2-D (rows, dim) tensor
            t = torch.vstack([t, t[-1:].repeat([missing, 1])])
        padded.append(t)

    return torch.stack(padded)
318

319

320

321
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    """Gather the cond active at ``current_step`` for every sub-prompt of every prompt.

    Returns ``(conds_list, stacked)``:

    - ``conds_list``: one list per prompt of ``(row, weight)`` pairs. ``row``
      indexes into ``stacked`` and ``weight`` is the sub-prompt weight.
    - ``stacked``: all selected conds combined with ``stack_conds``. For tensor
      conds it is moved to the first cond's device and dtype; for dict conds it
      is a DictWithShape with one stacked tensor per key.

    Within each schedule the active entry is the first one whose ``end_at_step``
    is >= ``current_step``. Past the end of the schedule, the last entry is used.
    """
    param = c.batch[0][0].schedules[0].cond

    tensors = []
    conds_list = []

    for composable_prompts in c.batch:
        conds_for_batch = []

        for composable_prompt in composable_prompts:
            schedules = composable_prompt.schedules
            # default to the final entry so steps beyond the schedule keep the last prompt
            target_index = len(schedules) - 1
            for current, entry in enumerate(schedules):
                if current_step <= entry.end_at_step:
                    target_index = current
                    break

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    if isinstance(tensors[0], dict):
        keys = list(tensors[0].keys())
        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
    else:
        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

    return conds_list, stacked
350

351

352
# Tokenizer for parse_prompt_attention. Alternatives are tried in order, so escapes
# come before the bare brackets they escape:
#   \( \) \[ \] \\   escaped characters (emitted literally)
#   \                a lone backslash
#   ( [              opening brackets
#   :<number>)       closing round bracket with an explicit weight (group 1)
#   ) ]              closing brackets
#   other text       a run of characters that are none of \ ( ) [ ] :
#   :                a colon that is not part of a weight
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# The whole word BREAK, together with the whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets left unclosed apply up to the end of the text. The word BREAK
    becomes a separate ["BREAK", -1] entry. A malformed weight such as
    "(abc:.)" is kept as text, and the bracket uses the default 1.1.
    Adjacent pieces with equal weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # scale every piece emitted since the matching opening bracket
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith('\\'):
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            try:
                multiplier = float(weight)
            except ValueError:
                # e.g. "(a:.)" or "(a:1.2.3)": keep the bad weight as text instead of crashing
                res.append([token[:-1], 1.0])
                multiplier = round_bracket_multiplier
            multiply_range(round_brackets.pop(), multiplier)
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # plain text, or an unmatched closer/colon: split out BREAK markers
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # brackets never closed apply to everything after them
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
459

460
if __name__ == "__main__":
    # Running this file directly executes the doctests in this module's docstrings.
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
    # torch is imported only when this file is used as a module, so the doctests
    # (which don't need torch) run without paying for the torch import.
    import torch
465

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.