stable-diffusion-webui

Форк
0
/
prompts_from_file.py 
191 строка · 6.4 Кб
1
import copy
2
import random
3
import shlex
4

5
import modules.scripts as scripts
6
import gradio as gr
7

8
from modules import sd_samplers, errors, sd_models
9
from modules.processing import Processed, process_images
10
from modules.shared import state
11

12

13
def process_model_tag(tag):
    """Resolve a checkpoint tag to the name of the closest matching checkpoint.

    Looks the tag up through ``sd_models.get_closet_checkpoint_match`` and
    returns the ``name`` attribute of the result.

    Raises:
        AssertionError: if no checkpoint matches *tag*.
    """
    checkpoint = sd_models.get_closet_checkpoint_match(tag)
    assert checkpoint is not None, f'Unknown checkpoint: {tag}'
    return checkpoint.name
17

18

19
def process_string_tag(tag):
    """Return *tag* unchanged; converter for options whose value is plain text."""
    value = tag
    return value
21

22

23
def process_int_tag(tag):
    """Convert *tag* to ``int``; conversion errors from ``int()`` propagate."""
    number = int(tag)
    return number
25

26

27
def process_float_tag(tag):
    """Convert *tag* to ``float``; conversion errors from ``float()`` propagate."""
    number = float(tag)
    return number
29

30

31
def process_boolean_tag(tag):
    """Convert a textual flag value to ``bool``.

    Returns ``True`` only when *tag* spells "true"; any other text yields
    ``False``. The comparison ignores letter case and surrounding whitespace
    so that values such as ``True`` or ``TRUE`` are not silently treated as
    ``False``.
    """
    return tag.strip().lower() == "true"
33

34

35
# Maps every supported "--<name>" option to the function that converts its
# textual value into the value stored for that option.
prompt_tags = {
    # checkpoint and output locations
    "sd_model": process_model_tag,
    "outpath_samples": process_string_tag,
    "outpath_grids": process_string_tag,
    # prompt text
    "prompt_for_display": process_string_tag,
    "prompt": process_string_tag,
    "negative_prompt": process_string_tag,
    "styles": process_string_tag,
    # seeds
    "seed": process_int_tag,
    "subseed_strength": process_float_tag,
    "subseed": process_int_tag,
    "seed_resize_from_h": process_int_tag,
    "seed_resize_from_w": process_int_tag,
    # sampling
    "sampler_index": process_int_tag,
    "sampler_name": process_string_tag,
    "batch_size": process_int_tag,
    "n_iter": process_int_tag,
    "steps": process_int_tag,
    "cfg_scale": process_float_tag,
    # image size
    "width": process_int_tag,
    "height": process_int_tag,
    # boolean switches
    "restore_faces": process_boolean_tag,
    "tiling": process_boolean_tag,
    "do_not_save_samples": process_boolean_tag,
    "do_not_save_grid": process_boolean_tag,
}
61

62

63
def cmdargs(line):
    """Parse one command-line style prompt line into a dict of option values.

    The line is tokenized with ``shlex.split``. Every option must be written
    as ``--name value``; ``name`` has to be a key of ``prompt_tags`` and the
    value is converted by the function registered there. For ``prompt`` and
    ``negative_prompt`` all following tokens up to the next one starting with
    ``--`` are joined with single spaces, so the text need not be quoted.
    A ``sampler_name`` value is looked up (lower-cased) in
    ``sd_samplers.samplers_map`` before conversion.

    Returns:
        dict mapping option names (without the leading ``--``) to values.

    Raises:
        AssertionError: token does not start with ``--``, an option has no
            value, or the option name is unknown.
    """
    tokens = shlex.split(line)
    total = len(tokens)
    parsed = {}
    idx = 0

    while idx < total:
        arg = tokens[idx]

        assert arg.startswith("--"), f'must start with "--": {arg}'
        assert idx + 1 < total, f'missing argument for command line option {arg}'

        tag = arg[2:]

        if tag in ("prompt", "negative_prompt"):
            # The first value token is always taken; after that, keep
            # consuming tokens until the next option marker.
            end = idx + 2
            while end < total and not tokens[end].startswith("--"):
                end += 1
            parsed[tag] = " ".join(tokens[idx + 1:end])
            idx = end
            continue

        func = prompt_tags.get(tag, None)
        assert func, f'unknown commandline option: {arg}'

        val = tokens[idx + 1]
        if tag == "sampler_name":
            val = sd_samplers.samplers_map.get(val.lower(), None)

        parsed[tag] = func(val)
        idx += 2

    return parsed
100

101

102
def load_prompt_file(file):
    """Turn the content of an uploaded prompt file into textbox text.

    Args:
        file: the file content as bytes, or ``None`` when there is no file.

    Returns:
        A 3-tuple: ``None``, the new text (or a no-op ``gr.update()`` when
        *file* is ``None``), and ``gr.update(lines=7)``. The text is the
        UTF-8 decoded content (undecodable bytes ignored) with every line
        stripped of surrounding whitespace.
    """
    if file is None:
        return None, gr.update(), gr.update(lines=7)

    text = file.decode('utf8', errors='ignore')
    cleaned = "\n".join(row.strip() for row in text.split("\n"))
    return None, cleaned, gr.update(lines=7)
108

109

110
class Script(scripts.Script):
    """Run one generation job per line of a list of prompt inputs.

    Every non-empty line is either a plain prompt or, when it contains
    ``--``, a command-line style set of per-job overrides parsed by
    ``cmdargs``.
    """

    def title(self):
        """Return the script's title."""
        return "Prompts from file or textbox"

    def ui(self, is_img2img):
        """Create the script's controls.

        Returns the components whose values are passed to ``run``, in order:
        iterate-seed checkbox, same-seed checkbox, prompt position radio and
        the prompt textbox. The file upload only fills the textbox.
        """
        checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
        checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
        prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")

        prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
        file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))

        file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)

        # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
        # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
        # be unclear to the user that shift-enter is needed.
        prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
        return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]

    def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
        """Process every non-empty line of *prompt_txt* as a separate job.

        Args:
            p: processing object used as the template; each job runs on a
                shallow copy of it with the line's overrides applied.
            checkbox_iterate: when true, ``p.seed`` is advanced by
                ``p.batch_size * p.n_iter`` after every job.
            checkbox_iterate_batch: when true (or *checkbox_iterate* is) and
                ``p.seed`` is ``-1``, one random seed is drawn up front.
            prompt_position: ``"start"`` puts the line's prompt before
                ``p.prompt``; any other value puts it after.
            prompt_txt: the prompt list, one job per line.

        Returns:
            A ``Processed`` object holding the images, prompts and infotexts
            of all jobs.
        """
        lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]

        p.do_not_save_grid = True

        job_count = 0
        jobs = []

        for line in lines:
            if "--" in line:
                try:
                    args = cmdargs(line)
                except Exception:
                    # A line that fails to parse is reported and then used verbatim as a prompt.
                    errors.report(f"Error parsing line {line} as commandline", exc_info=True)
                    args = {"prompt": line}
            else:
                args = {"prompt": line}

            job_count += args.get("n_iter", p.n_iter)

            jobs.append(args)

        print(f"Will process {len(lines)} lines in {job_count} jobs.")
        if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
            p.seed = int(random.randrange(4294967294))

        state.job_count = job_count

        images = []
        all_prompts = []
        infotexts = []
        for args in jobs:
            state.job = f"{state.job_no + 1} out of {state.job_count}"

            copy_p = copy.copy(p)
            for k, v in args.items():
                if k == "sd_model":
                    # copy.copy() is shallow, so copy_p shares its override_settings
                    # dict with p and with every other copy. Rebind to a fresh dict
                    # instead of mutating it, otherwise this line's checkpoint would
                    # leak into all following lines and into the caller's object.
                    copy_p.override_settings = {**copy_p.override_settings, 'sd_model_checkpoint': v}
                else:
                    setattr(copy_p, k, v)

            # Combine the line's prompt with the base prompt when both are non-empty.
            if args.get("prompt") and p.prompt:
                if prompt_position == "start":
                    copy_p.prompt = args.get("prompt") + " " + p.prompt
                else:
                    copy_p.prompt = p.prompt + " " + args.get("prompt")

            if args.get("negative_prompt") and p.negative_prompt:
                if prompt_position == "start":
                    copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
                else:
                    copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")

            proc = process_images(copy_p)
            images += proc.images

            if checkbox_iterate:
                p.seed = p.seed + (p.batch_size * p.n_iter)
            all_prompts += proc.all_prompts
            infotexts += proc.infotexts

        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
192

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.