# stable-diffusion-webui — legacy CLIP prompt processing
# (original listing: 82 lines · 3.5 KB)
1from modules import sd_hijack_clip
2from modules import shared
3
4
def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    """Legacy (pre-rework) conversion of prompts into padded token rows.

    For each prompt in *texts*: tokenize it, apply the old emphasis handling
    (tokens found in ``self.token_mults`` multiply the running weight of all
    following tokens — presumably the paren/bracket tokens; confirm against
    where ``token_mults`` is built), and splice in custom embeddings found by
    ``self.hijack.embedding_db``.

    Returns a 6-tuple::

        (batch_multipliers, remade_batch_tokens, used_custom_terms,
         hijack_comments, hijack_fixes, token_count)

    where every token row and multiplier row is padded/truncated to exactly
    ``maxlen`` entries and bracketed with ``id_start`` / ``id_end``.
    """
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []    # (embedding name, checksum) pairs actually used
    remade_batch_tokens = []  # one padded token row per prompt
    hijack_comments = []      # user-visible warnings (e.g. truncation)
    hijack_fixes = []         # per-prompt (position, embedding) insertions
    # NOTE(review): token_count is only assigned inside the cache-miss branch
    # below, so for a batch of fully-cached prompts it keeps its value from
    # the last uncached prompt (or stays 0) — confirm this is intended.
    token_count = 0

    # tuple(tokens) -> (padded remade_tokens, fixes, UNPADDED multipliers)
    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)  # lists are unhashable; key on a tuple

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0  # running emphasis multiplier

            i = 0
            while i < len(tokens):
                token = tokens[i]

                # Does a custom embedding start at position i of the prompt?
                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                # Emphasis lookup is disabled entirely when the user picked
                # the "None" emphasis option.
                mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
                if mult_change is not None:
                    # Emphasis token: scales subsequent tokens, emits nothing.
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    # Ordinary token: keep it with the current weight.
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    # Embedding: reserve emb_len zero-token slots (filled in
                    # later via hijack fixes) in place of its name's tokens.
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                # Prompt too long: reconstruct (roughly) the words that will
                # be cut off, for the truncation warning.
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            # Pad with id_end, truncate to maxlen-2, then bracket with the
            # start/end ids so the row is exactly maxlen long.
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        # Multipliers are cached UNPADDED (stored above before this point),
        # so pad/bracket them on both the cache-hit and cache-miss paths.
        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)
    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70
71
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    """Legacy forward pass: run the old prompt pipeline over *texts*, record
    comments/fixes on the hijack object, and encode the resulting tokens."""
    (batch_multipliers, remade_batch_tokens, used_custom_terms,
     hijack_comments, hijack_fixes, _token_count) = process_text_old(self, texts)

    hijack = self.hijack
    hijack.comments += hijack_comments

    # Surface which custom embeddings were substituted into the prompt.
    if used_custom_terms:
        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
        hijack.comments.append(f"Used embeddings: {embedding_names}")

    hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
83