stable-diffusion-webui

Форк
0
233 строки · 9.3 Кб
1
from pathlib import Path
2
from modules import errors
3
import csv
4
import os
5
import typing
6
import shutil
7

8

9
class PromptStyle(typing.NamedTuple):
    """A named prompt style: text merged into the prompt and the negative prompt.

    Either text may contain a ``{prompt}`` placeholder (see merge_prompts).
    """

    # Display name; StyleDatabase also uses it as the dictionary key for the style.
    name: str
    # Text for the positive prompt; None for list-divider entries.
    prompt: typing.Optional[str]
    # Text for the negative prompt; None for list-divider entries.
    negative_prompt: typing.Optional[str]
    # CSV file the style belongs to; "do_not_save" marks dividers, None means
    # "use the database's default path" (see StyleDatabase.get_style_paths).
    path: typing.Optional[str] = None
14

15

16
def merge_prompts(style_prompt: str, prompt: str) -> str:
17
    if "{prompt}" in style_prompt:
18
        res = style_prompt.replace("{prompt}", prompt)
19
    else:
20
        parts = filter(None, (prompt.strip(), style_prompt.strip()))
21
        res = ", ".join(parts)
22

23
    return res
24

25

26
def apply_styles_to_prompt(prompt, styles):
    """Apply each style text in *styles*, in order, to *prompt*.

    Every step feeds the result of the previous merge into merge_prompts, so
    later styles are merged around/after earlier ones.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged
31

32

33
def extract_style_text_from_prompt(style_text, prompt):
    """Try to remove a style's text from a prompt (the inverse of merge_prompts).

    If the style text contains the ``{prompt}`` placeholder, the stripped prompt
    must start with the text before the first placeholder and end with the text
    after it (without the two overlapping); the part in between is returned.
    Otherwise the stripped prompt must end with the stripped style text; that
    suffix is removed, together with a ", " separator directly in front of it.
    A ``None`` style text is treated as "".

    Returns ``(True, extracted_prompt)`` on a match, otherwise ``(False, prompt)``
    with *prompt* returned unchanged.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    """

    stripped_prompt = prompt.strip()
    stripped_style_text = (style_text or "").strip()

    if "{prompt}" in stripped_style_text:
        # Split on the first placeholder only: the old split("{prompt}", 2)
        # produced three parts for a style with two placeholders and raised
        # ValueError when unpacking.
        left, _, right = stripped_style_text.partition("{prompt}")
        # The prefix and suffix must not overlap inside the prompt; otherwise the
        # slice below would yield a bogus (empty or truncated) "match".
        if (
            len(stripped_prompt) >= len(left) + len(right)
            and stripped_prompt.startswith(left)
            and stripped_prompt.endswith(right)
        ):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
    elif stripped_prompt.endswith(stripped_style_text):
        extracted = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]

        # merge_prompts joins prompt and style with ", "; drop that separator too.
        if extracted.endswith(', '):
            extracted = extracted[:-2]

        return True, extracted

    return False, prompt
59

60

61
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches both, returns True plus the prompt and negative prompt with the style
    text removed. Otherwise, returns False with the original prompt and negative prompt.

    A style with neither a prompt nor a negative prompt (such as a list divider)
    never matches. A None prompt or negative prompt on an otherwise non-empty
    style is treated as "" instead of crashing.
    """
    style_prompt = style.prompt or ""
    style_negative_prompt = style.negative_prompt or ""

    if not style_prompt and not style_negative_prompt:
        return False, prompt, negative_prompt

    match_positive, extracted_positive = extract_style_text_from_prompt(style_prompt, prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style_negative_prompt, negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative
79

80

81
class StyleDatabase:
82
    def __init__(self, paths: list[str | Path]):
83
        self.no_style = PromptStyle("None", "", "", None)
84
        self.styles = {}
85
        self.paths = paths
86
        self.all_styles_files: list[Path] = []
87

88
        folder, file = os.path.split(self.paths[0])
89
        if '*' in file or '?' in file:
90
            # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
91
            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
92
            self.paths.insert(0, self.default_path)
93
        else:
94
            self.default_path = Path(self.paths[0])
95

96
        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
97

98
        self.reload()
99

100
    def reload(self):
101
        """
102
        Clears the style database and reloads the styles from the CSV file(s)
103
        matching the path used to initialize the database.
104
        """
105
        self.styles.clear()
106

107
        # scans for all styles files
108
        all_styles_files = []
109
        for pattern in self.paths:
110
            folder, file = os.path.split(pattern)
111
            if '*' in file or '?' in file:
112
                found_files = Path(folder).glob(file)
113
                [all_styles_files.append(file) for file in found_files]
114
            else:
115
                # if os.path.exists(pattern):
116
                all_styles_files.append(Path(pattern))
117

118
        # Remove any duplicate entries
119
        seen = set()
120
        self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
121

122
        for styles_file in self.all_styles_files:
123
            if len(all_styles_files) > 1:
124
                # add divider when more than styles file
125
                # '---------------- STYLES ----------------'
126
                divider = f' {styles_file.stem.upper()} '.center(40, '-')
127
                self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
128
            if styles_file.is_file():
129
                self.load_from_csv(styles_file)
130

131
    def load_from_csv(self, path: str | Path):
132
        try:
133
            with open(path, "r", encoding="utf-8-sig", newline="") as file:
134
                reader = csv.DictReader(file, skipinitialspace=True)
135
                for row in reader:
136
                    # Ignore empty rows or rows starting with a comment
137
                    if not row or row["name"].startswith("#"):
138
                        continue
139
                    # Support loading old CSV format with "name, text"-columns
140
                    prompt = row["prompt"] if "prompt" in row else row["text"]
141
                    negative_prompt = row.get("negative_prompt", "")
142
                    # Add style to database
143
                    self.styles[row["name"]] = PromptStyle(
144
                        row["name"], prompt, negative_prompt, str(path)
145
                    )
146
        except Exception:
147
            errors.report(f'Error loading styles from {path}: ', exc_info=True)
148

149
    def get_style_paths(self) -> set:
150
        """Returns a set of all distinct paths of files that styles are loaded from."""
151
        # Update any styles without a path to the default path
152
        for style in list(self.styles.values()):
153
            if not style.path:
154
                self.styles[style.name] = style._replace(path=str(self.default_path))
155

156
        # Create a list of all distinct paths, including the default path
157
        style_paths = set()
158
        style_paths.add(str(self.default_path))
159
        for _, style in self.styles.items():
160
            if style.path:
161
                style_paths.add(style.path)
162

163
        # Remove any paths for styles that are just list dividers
164
        style_paths.discard("do_not_save")
165

166
        return style_paths
167

168
    def get_style_prompts(self, styles):
169
        return [self.styles.get(x, self.no_style).prompt for x in styles]
170

171
    def get_negative_style_prompts(self, styles):
172
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
173

174
    def apply_styles_to_prompt(self, prompt, styles):
175
        return apply_styles_to_prompt(
176
            prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
177
        )
178

179
    def apply_negative_styles_to_prompt(self, prompt, styles):
180
        return apply_styles_to_prompt(
181
            prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
182
        )
183

184
    def save_styles(self, path: str = None) -> None:
185
        # The path argument is deprecated, but kept for backwards compatibility
186

187
        style_paths = self.get_style_paths()
188

189
        csv_names = [os.path.split(path)[1].lower() for path in style_paths]
190

191
        for style_path in style_paths:
192
            # Always keep a backup file around
193
            if os.path.exists(style_path):
194
                shutil.copy(style_path, f"{style_path}.bak")
195

196
            # Write the styles to the CSV file
197
            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
198
                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
199
                writer.writeheader()
200
                for style in (s for s in self.styles.values() if s.path == style_path):
201
                    # Skip style list dividers, e.g. "STYLES.CSV"
202
                    if style.name.lower().strip("# ") in csv_names:
203
                        continue
204
                    # Write style fields, ignoring the path field
205
                    writer.writerow(
206
                        {k: v for k, v in style._asdict().items() if k != "path"}
207
                    )
208

209
    def extract_styles_from_prompt(self, prompt, negative_prompt):
210
        extracted = []
211

212
        applicable_styles = list(self.styles.values())
213

214
        while True:
215
            found_style = None
216

217
            for style in applicable_styles:
218
                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
219
                    style, prompt, negative_prompt
220
                )
221
                if is_match:
222
                    found_style = style
223
                    prompt = new_prompt
224
                    negative_prompt = new_neg_prompt
225
                    break
226

227
            if not found_style:
228
                break
229

230
            applicable_styles.remove(found_style)
231
            extracted.append(found_style.name)
232

233
        return list(reversed(extracted)), prompt, negative_prompt
234

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.