stable-diffusion-webui

sd_hijack_clip.py 
import math
from collections import namedtuple

import torch

from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules.shared import opts


class PromptChunk:
    """
    This object contains token ids, weights (multipliers, e.g. 1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
    Each PromptChunk contains exactly 77 tokens, which includes one start and one end token,
    so just 75 tokens come from the prompt.
    """

    def __init__(self):
        self.tokens = []
        self.multipliers = []
        self.fixes = []


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. These objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
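# Hypothetical example: if an embedding "my-style" occupies 4 vectors and starts at position 10 of a chunk,
# the chunk carries PromptChunkFix(offset=10, embedding=<my-style Embedding>), and positions 10..13 of
# chunk.tokens hold placeholder ids that EmbeddingsWithFixes later overwrites with the embedding's vectors.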


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in the prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        self.chunk_length = 75

        self.is_trainable = getattr(wrapped, 'is_trainable', False)
        self.input_key = getattr(wrapped, 'input_key', 'txt')
        self.legacy_ucg_val = None

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
        All python lists with tokens are assumed to have the same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        the model - it can be 768 or 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, <end-of-text> tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue
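            # The BREAK keyword (parsed by prompt_parser with a weight of -1) forces a chunk boundary,
            # so e.g. "a cat BREAK a dog" is encoded as two separate 75-token chunks.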

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                emb_len = int(embedding.vectors)
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens
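                # Hypothetical example: an embedding with 4 vectors adds one PromptChunkFix at the current
                # offset, 4 placeholder token ids (0) and 4 copies of the current weight to the chunk.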

        if chunk.tokens or not chunks:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; passes texts through the transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
        An example shape returned by this function can be: (2, 77, 768).
        For SDXL, instead of returning just the one tensor above, it returns a tuple of two: the second one, with shape (B, 1280), holds pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for _position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
            hashes = []
            for name, embedding in used_embeddings.items():
                shorthash = embedding.shorthash
                if not shorthash:
                    continue

                name = name.replace(":", "").replace(",", "")
                hashes.append(f"{name}: {shorthash}")

            if hashes:
                if self.hijack.extra_generation_params.get("TI hashes"):
                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

        if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
            self.hijack.extra_generation_params["Emphasis"] = opts.emphasis

        if getattr(self.wrapped, 'return_pooled', False):
            return torch.hstack(zs), zs[0].pooled
        else:
            return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends a single prompt chunk to be encoded by the transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        pooled = getattr(z, 'pooled', None)

        emphasis = sd_emphasis.get_current_option(opts.emphasis)()
        emphasis.tokens = remade_batch_tokens
        emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
        emphasis.z = z

        emphasis.after_transformers()

        z = emphasis.z

        if pooled is not None:
            z.pooled = pooled

        return z


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
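        # Worked example: a hypothetical vocab entry "((" gets mult 1.1 * 1.1 == 1.21, while "[" gets 1 / 1.1 ≈ 0.909.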

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
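        # CLIP_stop_at_last_layers is the "Clip skip" setting: with a value of 2, the penultimate hidden state
        # (hidden_states[-2]) is used and re-normalized below instead of the final layer's output.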

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
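        # For SD1's CLIP text encoder this is a (token_count, 768) tensor of raw (pre-transformer) token
        # embeddings, capped at nvpt tokens by max_length above.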

        return embedded


class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
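        # The wrapped SDXL embedder exposes `layer` ("last" or "hidden") and, for "hidden", a `layer_idx`
        # selecting which hidden state to return; SDXL configs typically pick a penultimate hidden layer.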

        if self.wrapped.layer == "last":
            z = outputs.last_hidden_state
        else:
            z = outputs.hidden_states[self.wrapped.layer_idx]

        return z
