stable-diffusion-webui

Форк
0
/
sd_hijack_clip_old.py 
82 строки · 3.5 Кб
1
from modules import sd_hijack_clip
2
from modules import shared
3

4

5
def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
6
    id_start = self.id_start
7
    id_end = self.id_end
8
    maxlen = self.wrapped.max_length  # you get to stay at 77
9
    used_custom_terms = []
10
    remade_batch_tokens = []
11
    hijack_comments = []
12
    hijack_fixes = []
13
    token_count = 0
14

15
    cache = {}
16
    batch_tokens = self.tokenize(texts)
17
    batch_multipliers = []
18
    for tokens in batch_tokens:
19
        tuple_tokens = tuple(tokens)
20

21
        if tuple_tokens in cache:
22
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
23
        else:
24
            fixes = []
25
            remade_tokens = []
26
            multipliers = []
27
            mult = 1.0
28

29
            i = 0
30
            while i < len(tokens):
31
                token = tokens[i]
32

33
                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
34

35
                mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
36
                if mult_change is not None:
37
                    mult *= mult_change
38
                    i += 1
39
                elif embedding is None:
40
                    remade_tokens.append(token)
41
                    multipliers.append(mult)
42
                    i += 1
43
                else:
44
                    emb_len = int(embedding.vec.shape[0])
45
                    fixes.append((len(remade_tokens), embedding))
46
                    remade_tokens += [0] * emb_len
47
                    multipliers += [mult] * emb_len
48
                    used_custom_terms.append((embedding.name, embedding.checksum()))
49
                    i += embedding_length_in_tokens
50

51
            if len(remade_tokens) > maxlen - 2:
52
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
53
                ovf = remade_tokens[maxlen - 2:]
54
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
55
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
56
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
57

58
            token_count = len(remade_tokens)
59
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
60
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
61
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
62

63
        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
64
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
65

66
        remade_batch_tokens.append(remade_tokens)
67
        hijack_fixes.append(fixes)
68
        batch_multipliers.append(multipliers)
69
    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70

71

72
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
73
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
74

75
    self.hijack.comments += hijack_comments
76

77
    if used_custom_terms:
78
        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
79
        self.hijack.comments.append(f"Used embeddings: {embedding_names}")
80

81
    self.hijack.fixes = hijack_fixes
82
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.