stable-diffusion-webui
361 lines · 14.7 KB
1import math2from collections import namedtuple3
4import torch5
6from modules import prompt_parser, devices, sd_hijack, sd_emphasis7from modules.shared import opts8
9
class PromptChunk:
    """Holds the data for one 77-token slice of a tokenized prompt.

    A short prompt fits into a single PromptChunk; longer prompts are split
    across several. Each chunk carries exactly 77 token ids (one start token,
    75 prompt tokens, one end token), one weight multiplier per token
    (e.g. the 1.4 from emphasis syntax), and a list of textual-inversion
    embedding placements (PromptChunkFix entries).
    """

    def __init__(self):
        # token ids, per-token weight multipliers, and embedding fix markers
        self.tokens, self.multipliers, self.fixes = [], [], []
23
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
28
29
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        # 75 usable prompt tokens per chunk; start/end tokens bring each chunk to 77
        self.chunk_length = 75

        self.is_trainable = getattr(wrapped, 'is_trainable', False)
        self.input_key = getattr(wrapped, 'input_key', 'txt')
        self.legacy_ucg_val = None

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        # start token followed by 76 end tokens = 77 tokens total
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        # round up to the next multiple of chunk_length (minimum one chunk)
        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
        All python lists with tokens are assumed to have same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        model - can be 768 and 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If text produces fewer tokens than nvpt, only this many is returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        # parse emphasis syntax like (word:1.4) unless emphasis is disabled
        if opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, the <end-of-text> padding tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            # pad the chunk out to exactly chunk_length tokens with end-of-text
            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            # wrap with start/end tokens to reach the full 77-token length
            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # the literal word BREAK (parsed with weight -1) forces a chunk boundary
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    # tokens after the last comma get relocated into the next chunk
                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    # ordinary token: append it with the weight of its source text segment
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                # textual inversion embedding found at this position; its vectors
                # must fit in one chunk, so start a new chunk if they would not
                emb_len = int(embedding.vectors)
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                # placeholder token ids (0) are replaced later by EmbeddingsWithFixes
                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        # flush the trailing partial chunk; an empty prompt still yields one chunk
        if chunk.tokens or not chunks:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
        An example shape returned by this function can be: (2, 77, 768).
        For SDXL, instead of returning one tensor above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            # pad shorter prompts in the batch with empty chunks so all rows align
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            # hijack.fixes is consumed by encode_with_transformers (via EmbeddingsWithFixes)
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for _position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        # record textual inversion hashes for the generation infotext
        if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
            hashes = []
            for name, embedding in used_embeddings.items():
                shorthash = embedding.shorthash
                if not shorthash:
                    continue

                # strip characters that would break the "name: hash" infotext format
                name = name.replace(":", "").replace(",", "")
                hashes.append(f"{name}: {shorthash}")

            if hashes:
                if self.hijack.extra_generation_params.get("TI hashes"):
                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

        # only record the emphasis mode when the prompt actually uses emphasis syntax
        if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
            self.hijack.extra_generation_params["Emphasis"] = opts.emphasis

        if getattr(self.wrapped, 'return_pooled', False):
            return torch.hstack(zs), zs[0].pooled
        else:
            return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                # everything after the first end-of-text token becomes padding
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        pooled = getattr(z, 'pooled', None)

        # apply the selected emphasis implementation to scale token outputs by their multipliers
        emphasis = sd_emphasis.get_current_option(opts.emphasis)()
        emphasis.tokens = remade_batch_tokens
        emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
        emphasis.z = z

        emphasis.after_transformers()

        z = emphasis.z

        # re-attach pooled output, which emphasis processing may have dropped
        if pooled is not None:
            z.pooled = pooled

        return z
294
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    """Prompt wrapper specialized for the SD1-style FrozenCLIPEmbedder (huggingface CLIP tokenizer/transformer)."""

    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        # precompute emphasis multipliers for vocab entries that contain bracket characters
        self.token_mults = {}
        bracket_tokens = [(token_text, token_id) for token_text, token_id in vocab.items()
                          if '(' in token_text or ')' in token_text or '[' in token_text or ']' in token_text]
        for token_text, token_id in bracket_tokens:
            mult = 1.0
            for ch in token_text:
                # '(' and ']' strengthen; ')' and '[' weaken (same 1.1 factor each)
                if ch == '(' or ch == ']':
                    mult *= 1.1
                elif ch == ')' or ch == '[':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[token_id] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        # SD1 uses the end-of-text token for padding as well
        self.id_pad = self.id_end

    def tokenize(self, texts):
        """Tokenize a batch of texts into lists of token ids, without truncation or special tokens."""
        return self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        """Run tokens through the CLIP text transformer, honoring the CLIP-skip setting."""
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            # CLIP skip: take an earlier hidden state and re-apply the final layer norm
            hidden = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            return self.wrapped.transformer.text_model.final_layer_norm(hidden)

        return outputs.last_hidden_state

    def encode_embedding_init_text(self, init_text, nvpt):
        """Look up raw (pre-transformer) token embeddings for init_text, capped at nvpt tokens."""
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        token_embedding = embedding_layer.token_embedding.wrapped
        return token_embedding(ids.to(token_embedding.weight.device)).squeeze(0)
348
class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
    """SDXL variant: selects the hidden state layer according to the wrapped embedder's configuration."""

    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

    def encode_with_transformers(self, tokens):
        """Encode tokens, returning either the last hidden state or the configured intermediate layer."""
        # hidden states are only needed (and requested) when an intermediate layer is configured
        wants_hidden = self.wrapped.layer == "hidden"
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=wants_hidden)

        if self.wrapped.layer == "last":
            return outputs.last_hidden_state

        return outputs.hidden_states[self.wrapped.layer_idx]