# stable-diffusion-webui
# (original listing: 233 lines, 9.3 KB)
1from pathlib import Path2from modules import errors3import csv4import os5import typing6import shutil7
8
9class PromptStyle(typing.NamedTuple):10name: str11prompt: str | None12negative_prompt: str | None13path: str | None = None14
15
16def merge_prompts(style_prompt: str, prompt: str) -> str:17if "{prompt}" in style_prompt:18res = style_prompt.replace("{prompt}", prompt)19else:20parts = filter(None, (prompt.strip(), style_prompt.strip()))21res = ", ".join(parts)22
23return res24
25
def apply_styles_to_prompt(prompt, styles):
    """Fold each style text into ``prompt``, in order, using merge_prompts.

    ``styles`` is an iterable of style prompt texts (e.g. ``PromptStyle.prompt``
    values), not PromptStyle objects. Each style wraps or extends the result of
    the previous one; with no styles the prompt is returned unchanged.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged
32
def extract_style_text_from_prompt(style_text, prompt):
    """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.

    Only the first {prompt} placeholder is used for matching; a style with
    several placeholders does not match (instead of raising ValueError).
    A None style text is treated as an empty string, which matches any prompt.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    """

    stripped_prompt = prompt.strip()
    stripped_style_text = (style_text or "").strip()

    if "{prompt}" in stripped_style_text:
        # partition splits on the first placeholder only, so unpacking can never
        # fail; any further "{prompt}" stays in `right` and simply won't match.
        left, _, right = stripped_style_text.partition("{prompt}")
        # Prefix and suffix must fit without overlapping, otherwise the slice
        # below would be empty/garbled and a false match would be reported.
        fits = len(stripped_prompt) >= len(left) + len(right)
        if fits and stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            prompt = stripped_prompt[len(left):len(stripped_prompt) - len(right)]
            return True, prompt
    else:
        if stripped_prompt.endswith(stripped_style_text):
            prompt = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]

            # Drop the ", " separator that merge_prompts puts before an appended style.
            if prompt.endswith(', '):
                prompt = prompt[:-2]

            return True, prompt

    return False, prompt
60
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches, returns True plus the prompt and negative prompt with the style text
    removed. Otherwise, returns False with the original prompt and negative prompt.

    A style whose prompt and negative prompt are both empty or None (e.g. a list
    divider) never matches. If only one of them is None, it is treated as an
    empty string rather than crashing.
    """
    if not style.prompt and not style.negative_prompt:
        return False, prompt, negative_prompt

    # `or ""`: styles loaded from short CSV rows can have one field set to None.
    match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt or "", prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt or "", negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative
80
81class StyleDatabase:82def __init__(self, paths: list[str | Path]):83self.no_style = PromptStyle("None", "", "", None)84self.styles = {}85self.paths = paths86self.all_styles_files: list[Path] = []87
88folder, file = os.path.split(self.paths[0])89if '*' in file or '?' in file:90# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path91self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))92self.paths.insert(0, self.default_path)93else:94self.default_path = Path(self.paths[0])95
96self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]97
98self.reload()99
100def reload(self):101"""102Clears the style database and reloads the styles from the CSV file(s)
103matching the path used to initialize the database.
104"""
105self.styles.clear()106
107# scans for all styles files108all_styles_files = []109for pattern in self.paths:110folder, file = os.path.split(pattern)111if '*' in file or '?' in file:112found_files = Path(folder).glob(file)113[all_styles_files.append(file) for file in found_files]114else:115# if os.path.exists(pattern):116all_styles_files.append(Path(pattern))117
118# Remove any duplicate entries119seen = set()120self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]121
122for styles_file in self.all_styles_files:123if len(all_styles_files) > 1:124# add divider when more than styles file125# '---------------- STYLES ----------------'126divider = f' {styles_file.stem.upper()} '.center(40, '-')127self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")128if styles_file.is_file():129self.load_from_csv(styles_file)130
131def load_from_csv(self, path: str | Path):132try:133with open(path, "r", encoding="utf-8-sig", newline="") as file:134reader = csv.DictReader(file, skipinitialspace=True)135for row in reader:136# Ignore empty rows or rows starting with a comment137if not row or row["name"].startswith("#"):138continue139# Support loading old CSV format with "name, text"-columns140prompt = row["prompt"] if "prompt" in row else row["text"]141negative_prompt = row.get("negative_prompt", "")142# Add style to database143self.styles[row["name"]] = PromptStyle(144row["name"], prompt, negative_prompt, str(path)145)146except Exception:147errors.report(f'Error loading styles from {path}: ', exc_info=True)148
149def get_style_paths(self) -> set:150"""Returns a set of all distinct paths of files that styles are loaded from."""151# Update any styles without a path to the default path152for style in list(self.styles.values()):153if not style.path:154self.styles[style.name] = style._replace(path=str(self.default_path))155
156# Create a list of all distinct paths, including the default path157style_paths = set()158style_paths.add(str(self.default_path))159for _, style in self.styles.items():160if style.path:161style_paths.add(style.path)162
163# Remove any paths for styles that are just list dividers164style_paths.discard("do_not_save")165
166return style_paths167
168def get_style_prompts(self, styles):169return [self.styles.get(x, self.no_style).prompt for x in styles]170
171def get_negative_style_prompts(self, styles):172return [self.styles.get(x, self.no_style).negative_prompt for x in styles]173
174def apply_styles_to_prompt(self, prompt, styles):175return apply_styles_to_prompt(176prompt, [self.styles.get(x, self.no_style).prompt for x in styles]177)178
179def apply_negative_styles_to_prompt(self, prompt, styles):180return apply_styles_to_prompt(181prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]182)183
184def save_styles(self, path: str = None) -> None:185# The path argument is deprecated, but kept for backwards compatibility186
187style_paths = self.get_style_paths()188
189csv_names = [os.path.split(path)[1].lower() for path in style_paths]190
191for style_path in style_paths:192# Always keep a backup file around193if os.path.exists(style_path):194shutil.copy(style_path, f"{style_path}.bak")195
196# Write the styles to the CSV file197with open(style_path, "w", encoding="utf-8-sig", newline="") as file:198writer = csv.DictWriter(file, fieldnames=self.prompt_fields)199writer.writeheader()200for style in (s for s in self.styles.values() if s.path == style_path):201# Skip style list dividers, e.g. "STYLES.CSV"202if style.name.lower().strip("# ") in csv_names:203continue204# Write style fields, ignoring the path field205writer.writerow(206{k: v for k, v in style._asdict().items() if k != "path"}207)208
209def extract_styles_from_prompt(self, prompt, negative_prompt):210extracted = []211
212applicable_styles = list(self.styles.values())213
214while True:215found_style = None216
217for style in applicable_styles:218is_match, new_prompt, new_neg_prompt = extract_original_prompts(219style, prompt, negative_prompt220)221if is_match:222found_style = style223prompt = new_prompt224negative_prompt = new_neg_prompt225break226
227if not found_style:228break229
230applicable_styles.remove(found_style)231extracted.append(found_style.name)232
233return list(reversed(extracted)), prompt, negative_prompt234