stable-diffusion-webui
464 lines · 16.3 KB
1from __future__ import annotations2
3import re4from collections import namedtuple5import lark6
7# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
8# will be represented with prompt_schedule like this (assuming steps=100):
9# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
10# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
11# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
12# [75, 'fantasy landscape with a lake and an oak in background masterful']
13# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
14
# Grammar for the prompt-scheduling mini-language (see example above).
# - scheduled: "[before:after:NUMBER]" switches text at a given step
# - alternate: "[a|b|...]" cycles between the options every step
# - emphasized: "(...)", "(...:...)" and "[...]" are recognized here only so
#   their contents parse; attention weighting itself is done later by
#   parse_prompt_attention
# The !start rule also swallows stray brackets/colons/parens so unbalanced
# prompts still parse (falling back to treating them as plain text).
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
alternate: "[" prompt ("|" [prompt])+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
    """
    Build, for every prompt, a schedule: a list of [end_step, prompt_text]
    pairs, where prompt_text is used until sampling step end_step.

    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b d'], [10, 'a c d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c") # not handling this right now
    [[5, 'a c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    >>> g("[fe|]male")
    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
    >>> g("[fe|||]male")
    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
    >>> g("a [b:.5] c")
    [[10, 'a b c']]
    >>> g("a [b:1.5] c")
    [[5, 'a c'], [10, 'a b c']]
    """

    # In the hires pass (new scheduling), integer step values are offset by
    # the number of base steps already taken, and fractional values count
    # from 1.0 (the start of the hires pass) instead of 0.
    if hires_steps is None or use_old_scheduling:
        int_offset = 0
        flt_offset = 0
        steps = base_steps
    else:
        int_offset = base_steps
        flt_offset = 1.0
        steps = hires_steps

    def collect_steps(steps, tree):
        # Gather every step number at which the rendered prompt text changes.
        res = [steps]

        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                s = tree.children[-2]
                v = float(s)
                if use_old_scheduling:
                    # old behavior: values < 1 are fractions of total steps
                    v = v*steps if v<1 else v
                else:
                    # new behavior: "." in the literal marks a fraction,
                    # otherwise it's an absolute step count
                    if "." in s:
                        v = (v - flt_offset) * steps
                    else:
                        v = (v - int_offset)
                # normalize the NUMBER token in place to a clamped int step
                tree.children[-2] = min(steps, int(v))
                if tree.children[-2] >= 1:
                    res.append(tree.children[-2])

            def alternate(self, tree):
                # alternation changes text every single step
                res.extend(range(1, steps+1))

        CollectSteps().visit(tree)
        return sorted(set(res))

    def at_step(step, tree):
        # Render the parse tree to the prompt text active at `step`.
        # Transformer methods yield (returning generators) so __default__
        # can flatten nested results lazily; start() joins everything.
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                before, after, _, when, _ = args
                yield before or () if step <= when else after
            def alternate(self, args):
                # empty alternatives ("[a|]") become empty strings
                args = ["" if not arg else arg for arg in args]
                yield args[(step - 1) % len(args)]
            def start(self, args):
                def flatten(x):
                    if isinstance(x, str):
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args))
            def plain(self, args):
                yield args[0].value
            def __default__(self, data, children, meta):
                for child in children:
                    yield child
        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError:
            if 0:
                import traceback
                traceback.print_exc()
            # unparseable prompts are used verbatim for all steps
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    # compute each distinct prompt once, then map results back in order
    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]
135
# (end_at_step, cond): use conditioning `cond` up to sampling step end_at_step.
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
138
class SdConditioning(list):
    """
    A list of prompts for stable diffusion's conditioner model.

    Can also carry the width/height of the image being created (SDXL needs
    them). When `copy_from` is given — or when `prompts` is itself an
    SdConditioning — any attribute not passed explicitly is inherited from it.
    """
    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__(prompts)

        source = prompts if copy_from is None else copy_from

        self.is_negative_prompt = is_negative_prompt or getattr(source, 'is_negative_prompt', False)
        self.width = width or getattr(source, 'width', None)
        self.height = height or getattr(source, 'height', None)
155
156
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
    and the sampling step at which this condition is to be replaced by the next one.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor(...))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor(...)),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor(...))
        ]
    ]
    """
    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)

    cache = {}  # prompt text -> already-built cond schedule; identical prompts share one
    res = []

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):
        cond_schedule = cache.get(prompt)

        if cond_schedule is None:
            # one batched conditioner call for all scheduled variants of this prompt;
            # copy_from keeps width/height/negative metadata from the input list
            texts = SdConditioning([scheduled_text for _, scheduled_text in prompt_schedule], copy_from=prompts)
            conds = model.get_learned_conditioning(texts)

            cond_schedule = []
            for i, (end_at_step, _) in enumerate(prompt_schedule):
                # SDXL-style conditioners return a dict of tensors; take row i of each
                cond = {k: v[i] for k, v in conds.items()} if isinstance(conds, dict) else conds[i]
                cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))

            cache[prompt] = cond_schedule

        res.append(cond_schedule)

    return res
204
# Splits a prompt into subprompts on the word AND (composable diffusion).
re_AND = re.compile(r"\bAND\b")
# Captures subprompt text and its optional trailing ":weight", e.g. "a cat :1.2".
re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
208
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
    """
    Split each prompt on AND into weighted subprompts and deduplicate them.

    Returns a triple:
      res_indexes      - per input prompt, a list of (flat_index, weight) pairs
      prompt_flat_list - SdConditioning holding each unique subprompt text once
                         (inherits width/height/negative flag from prompts)
      prompt_indexes   - mapping from subprompt text to its flat-list index
    """
    # empty SdConditioning that still carries prompts' metadata
    prompt_flat_list = SdConditioning(prompts)
    prompt_flat_list.clear()

    prompt_indexes = {}
    res_indexes = []

    for prompt in prompts:
        pairs = []

        for subprompt in re_AND.split(prompt):
            m = re_weight.search(subprompt)
            if m is None:
                text, weight = subprompt, 1.0
            else:
                text, weight = m.groups()

            weight = 1.0 if weight is None else float(weight)

            if text in prompt_indexes:
                index = prompt_indexes[text]
            else:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index

            pairs.append((index, weight))

        res_indexes.append(pairs)

    return res_indexes, prompt_flat_list, prompt_indexes
239
class ComposableScheduledPromptConditioning:
    """One AND-component of a prompt: its cond schedule plus its blend weight."""

    def __init__(self, schedules, weight=1.0):
        # the ScheduledPromptConditioning list for this subprompt
        self.schedules: list[ScheduledPromptConditioning] = schedules
        # multiplier from the ":weight" suffix after AND (defaults to 1.0)
        self.weight: float = weight
245
class MulticondLearnedConditioning:
    """Batch of composable conditionings produced by get_multicond_learned_conditioning."""

    def __init__(self, shape, batch):
        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
251
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
    For each prompt, the list is obtained by splitting the prompt using the AND separator.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """
    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)

    # condition each unique subprompt once; res_indexes maps results back per prompt
    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)

    batch = [
        [ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]
        for indexes in res_indexes
    ]

    return MulticondLearnedConditioning(shape=(len(prompts),), batch=batch)
269
class DictWithShape(dict):
    """
    A dict of conditioning tensors that also exposes .shape, mirroring a plain
    tensor cond so downstream code can treat both forms uniformly.

    NOTE(review): the `shape` constructor argument is accepted but never
    stored; .shape is always derived from the 'crossattn' entry.
    """
    def __init__(self, x, shape):
        super().__init__(x)

    @property
    def shape(self):
        return self["crossattn"].shape
279
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
    # For every prompt in the batch, pick the conditioning whose schedule
    # covers current_step and stack the picks into one batch tensor
    # (or a DictWithShape of batch tensors when conds are dicts, SDXL-style).
    param = c[0][0].cond
    is_dict = isinstance(param, dict)

    if is_dict:
        dict_cond = param
        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
    else:
        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)

    for i, cond_schedule in enumerate(c):
        target_index = 0
        for current, entry in enumerate(cond_schedule):
            if current_step <= entry.end_at_step:
                # schedules are ordered by end_at_step, so the first schedule
                # still covering this step is the active one
                target_index = current
                break

        if is_dict:
            for k, param in cond_schedule[target_index].cond.items():
                res[k][i] = param
        else:
            res[i] = cond_schedule[target_index].cond

    return res
306
def stack_conds(tensors):
    """
    Stack conds into one tensor, padding shorter ones first.

    Prompts with wildly different lengths above the token limit produce
    tensors of different shapes that torch.stack cannot handle; each short
    tensor is extended by repeating its last vector until all match.
    Note: padding replaces entries of `tensors` in place, like the original.
    """
    token_count = max(t.shape[0] for t in tensors)

    for i, t in enumerate(tensors):
        missing = token_count - t.shape[0]
        if missing > 0:
            tail = t[-1:].repeat([missing, 1])
            tensors[i] = torch.vstack([t, tail])

    return torch.stack(tensors)
319
320
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    # For each batch item, gather the cond active at current_step for every
    # AND-subprompt. Returns (conds_list, stacked):
    #   conds_list - per batch item, list of (row index into stacked, weight)
    #   stacked    - all gathered conds stacked (tensor, or DictWithShape of tensors)
    param = c.batch[0][0].schedules[0].cond

    tensors = []
    conds_list = []

    for composable_prompts in c.batch:
        conds_for_batch = []

        for composable_prompt in composable_prompts:
            target_index = 0
            for current, entry in enumerate(composable_prompt.schedules):
                if current_step <= entry.end_at_step:
                    # schedules are ordered; first one still covering this step wins
                    target_index = current
                    break

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(composable_prompt.schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    if isinstance(tensors[0], dict):
        keys = list(tensors[0].keys())
        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
    else:
        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

    return conds_list, stacked
351
# Tokenizer for attention syntax: escaped characters, brackets/parens,
# an explicit ":weight)" close, plain text runs, and stray colons.
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# BREAK keyword separating prompt segments (surrounding whitespace absorbed).
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \\( - literal character '('
      \\[ - literal character '['
      \\) - literal character ')'
      \\] - literal character ']'
      \\\\ - literal character '\\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    ROUND_MULT = 1.1
    SQUARE_MULT = 1 / 1.1

    res = []
    round_starts = []   # stack of res positions where a '(' group opened
    square_starts = []  # stack of res positions where a '[' group opened

    def scale_from(position, multiplier):
        # apply multiplier to every chunk emitted since the group opened
        for entry in res[position:]:
            entry[1] *= multiplier

    for match in re_attention.finditer(text):
        token = match.group(0)
        explicit_weight = match.group(1)

        if token.startswith('\\'):
            # escaped character: emit the literal without the backslash
            res.append([token[1:], 1.0])
        elif token == '(':
            round_starts.append(len(res))
        elif token == '[':
            square_starts.append(len(res))
        elif explicit_weight is not None and round_starts:
            # ":1.3)"-style close carrying an explicit multiplier
            scale_from(round_starts.pop(), float(explicit_weight))
        elif token == ')' and round_starts:
            scale_from(round_starts.pop(), ROUND_MULT)
        elif token == ']' and square_starts:
            scale_from(square_starts.pop(), SQUARE_MULT)
        else:
            # plain text; BREAK separators become their own weight -1 markers
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # unbalanced opening brackets still apply their default multiplier
    for position in round_starts:
        scale_from(position, ROUND_MULT)

    for position in square_starts:
        scale_from(position, SQUARE_MULT)

    if not res:
        res = [["", 1.0]]

    # merge runs of identical weights into single chunks
    merged = [res[0]]
    for chunk, weight in res[1:]:
        if merged[-1][1] == weight:
            merged[-1][0] += chunk
        else:
            merged.append([chunk, weight])

    return merged
if __name__ == "__main__":
    # Running this module directly executes the doctests embedded above.
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
    # torch is only needed when imported as part of the app; deferring the
    # import here keeps the doctest run above fast.
    import torch  # doctest faster