stable-diffusion-webui
812 строк · 37.8 Кб
1from collections import namedtuple2from copy import copy3from itertools import permutations, chain4import random5import csv6import os.path7from io import StringIO8from PIL import Image9import numpy as np10
11import modules.scripts as scripts12import gradio as gr13
14from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors15from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img16from modules.shared import opts, state17import modules.shared as shared18import modules.sd_samplers19import modules.sd_models20import modules.sd_vae21import re22
23from modules.ui_components import ToolButton24
# Icon used on the "fill axis values" tool buttons next to each axis.
fill_values_symbol = "\U0001f4d2"  # 📒

# Pair of (axis option, list of axis values); run() stores one per axis on ``state``.
AxisInfo = namedtuple('AxisInfo', 'axis values')
29
def apply_field(field):
    """Return an axis applier that stores the axis value in attribute *field* of the processing object."""
    def setter(p, x, xs):
        setattr(p, field, x)

    return setter
36
def apply_prompt(p, x, xs):
    """Prompt search/replace: substitute the first axis value (the search term) with *x*.

    The replacement is applied to both ``p.prompt`` and ``p.negative_prompt``.

    Raises:
        RuntimeError: if the search term occurs in neither prompt.
    """
    search = xs[0]
    if not (search in p.prompt or search in p.negative_prompt):
        raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(search, x)
    p.negative_prompt = p.negative_prompt.replace(search, x)
44
def apply_order(p, x, xs):
    """Rearrange tokens of ``p.prompt`` into the order given by the permutation *x*.

    Each token of *x* is located in the prompt. The slots the tokens occupy (taken in prompt order)
    are refilled with the tokens in the order they appear in *x*. Text between the tokens is left
    untouched. ``xs`` is unused; it is part of the common ``apply(p, x, xs)`` signature.

    Raises:
        RuntimeError: if a token of *x* cannot be found in the prompt, including the case of a token
            listed more times than it occurs. ``p.prompt`` is not modified in that case.
    """
    prompt = p.prompt

    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
    token_order = []
    for token in x:
        position = prompt.find(token)
        if position == -1:
            # str.find() reports "missing" as -1; slicing with it would silently mangle the prompt.
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        token_order.append((position, token))

    token_order.sort(key=lambda t: t[0])

    # Split the prompt up, taking out the tokens
    prompt_parts = []
    for _, token in token_order:
        n = prompt.find(token)
        if n == -1:
            # Happens when a token is repeated in x more often than it occurs in the prompt.
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        prompt_parts.append(prompt[:n])
        prompt = prompt[n + len(token):]

    # Rebuild the prompt with the tokens in the order we want. The work above was done on a local
    # copy, so p.prompt only changes once every token has been located.
    p.prompt = "".join(part + token for part, token in zip(prompt_parts, x)) + prompt
69
def confirm_samplers(p, xs):
    """Validate sampler axis values before generation.

    Raises:
        RuntimeError: for the first value whose lower-cased name is not in ``sd_samplers.samplers_map``.
    """
    for sampler_name in xs:
        known = sampler_name.lower() in sd_samplers.samplers_map
        if not known:
            raise RuntimeError(f"Unknown sampler: {sampler_name}")
75
def apply_checkpoint(p, x, xs):
    """Select checkpoint *x* for this cell through ``p.override_settings['sd_model_checkpoint']``.

    Raises:
        RuntimeError: if no checkpoint matches *x*.
    """
    checkpoint_info = modules.sd_models.get_closet_checkpoint_match(x)
    if checkpoint_info is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")

    p.override_settings['sd_model_checkpoint'] = checkpoint_info.name
82
def confirm_checkpoints(p, xs):
    """Validate checkpoint axis values before generation.

    Raises:
        RuntimeError: for the first value with no matching checkpoint.
    """
    for checkpoint_name in xs:
        match_found = modules.sd_models.get_closet_checkpoint_match(checkpoint_name) is not None
        if not match_found:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint_name}")
88
def confirm_checkpoints_or_none(p, xs):
    """Validate checkpoint axis values, treating None, "", "None" and "none" as "no checkpoint".

    Raises:
        RuntimeError: for the first non-empty value with no matching checkpoint.
    """
    empty_markers = (None, "", "None", "none")
    for checkpoint_name in xs:
        if checkpoint_name in empty_markers:
            continue
        if modules.sd_models.get_closet_checkpoint_match(checkpoint_name) is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint_name}")
97
def apply_clip_skip(p, x, xs):
    """Write the clip-skip value *x* directly into the global options.

    This is global state: SharedSettingsStackHelper saves and restores it around the grid run.
    """
    setting_key = "CLIP_stop_at_last_layers"
    opts.data[setting_key] = x
101
def apply_upscale_latent_space(p, x, xs):
    """Enable latent-space hires upscaling unless the axis value is "0" (surrounding whitespace ignored)."""
    enabled = x.lower().strip() != '0'
    opts.data["use_scale_latent_for_hires_fix"] = enabled
108
def find_vae(name: str):
    """Resolve a VAE axis value to the value ``apply_vae`` passes as ``vae_file``.

    Matching is case-insensitive and ignores surrounding whitespace:
      * "auto" / "automatic" -> ``modules.sd_vae.unspecified``
      * "none"               -> ``None``
      * anything else        -> the entry of ``modules.sd_vae.vae_dict`` whose key is the shortest
                                one containing *name*; if nothing matches, a message is printed
                                and ``modules.sd_vae.unspecified`` is returned.
    """
    # Normalize once so the keyword checks and the substring search agree on the same text.
    wanted = name.lower().strip()

    if wanted in ('auto', 'automatic'):
        return modules.sd_vae.unspecified
    if wanted == 'none':
        return None

    # Shortest key first, so the most specific (least decorated) name wins on ambiguous substrings.
    choices = [x for x in sorted(modules.sd_vae.vae_dict, key=len) if wanted in x.lower()]
    if not choices:
        print(f"No VAE found for {name}; using automatic")
        return modules.sd_vae.unspecified

    return modules.sd_vae.vae_dict[choices[0]]
122
def apply_vae(p, x, xs):
    """Reload the current model's VAE weights using the VAE resolved from axis value *x*."""
    vae_file = find_vae(x)
    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=vae_file)
126
def apply_styles(p: "StableDiffusionProcessingTxt2Img", x: str, _):
    """Append the comma-separated style names in *x* to ``p.styles``.

    Whitespace around each name is stripped and empty entries are skipped, so "a, b" adds "a" and
    "b" instead of "a" and " b" (a name with a leading space would not match a saved style).
    """
    p.styles.extend(name.strip() for name in x.split(',') if name.strip())
130
def apply_uni_pc_order(p, x, xs):
    """Set the global UniPC order to *x*, clamped to at most ``p.steps - 1``."""
    clamped_order = min(x, p.steps - 1)
    opts.data["uni_pc_order"] = clamped_order
134
def apply_face_restore(p, opt, x):
    """Configure face restoration from the axis value *opt* (case-insensitive).

    "codeformer" / "gfpgan" enable restoration and select that model; otherwise restoration is
    enabled only for "true", "yes", "y" or "1". Note that *opt* is the axis value here; *x*
    receives the full value list.
    """
    value = opt.lower()
    model_names = {'codeformer': 'CodeFormer', 'gfpgan': 'GFPGAN'}

    if value in model_names:
        p.face_restoration_model = model_names[value]
        p.restore_faces = True
    else:
        p.restore_faces = value in ('true', 'yes', 'y', '1')
148
def apply_override(field, boolean: bool = False):
    """Return an axis applier that stores the axis value in ``p.override_settings[field]``.

    With ``boolean=True`` the value is converted with ``value.lower() == "true"`` before storing.
    """
    def override(p, x, xs):
        value = (x.lower() == "true") if boolean else x
        p.override_settings[field] = value

    return override
156
def boolean_choice(reverse: bool = False):
    """Return a choices callable yielding ["True", "False"], or ["False", "True"] when *reverse*.

    Each call returns a new list.
    """
    labels = ("True", "False")

    def choice():
        ordered = reversed(labels) if reverse else labels
        return list(ordered)

    return choice
162
def format_value_add_label(p, opt, x):
    """Format axis value *x* as "<axis label>: <value>" for grid legends.

    Floats (including float subclasses such as numpy.float64) are rounded to 8 decimal places so
    binary representation noise (e.g. 0.30000000000000004) does not leak into the labels.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"
169
def format_value(p, opt, x):
    """Return axis value *x* unlabelled; floats (and float subclasses) are rounded to 8 decimal places."""
    if isinstance(x, float):
        x = round(x, 8)
    return x
175
def format_value_join_list(p, opt, x):
    """Render a sequence of strings (e.g. a token permutation) as a comma-separated label."""
    separator = ", "
    return separator.join(x)
179
def do_nothing(p, x, xs):
    """Axis applier for the "Nothing" axis: leaves the processing object unchanged."""
183
def format_nothing(p, opt, x):
    """Label formatter for the "Nothing" axis: always an empty label."""
    empty_label = ""
    return empty_label
187
def format_remove_path(p, opt, x):
    """Label formatter that shows only the final path component of *x*."""
    _head, tail = os.path.split(x)
    return tail
191
def str_permutations(x):
    """Identity marker type for AxisOption.type.

    Axes declared with this type have their value list expanded into all permutations by
    process_axis; the function itself returns its argument unchanged.
    """
    result = x
    return result
196
def list_to_csv_string(data_list):
    """Serialize *data_list* as one CSV row (quoting as needed), without the trailing line break."""
    with StringIO() as buffer:
        writer = csv.writer(buffer)
        writer.writerow(data_list)
        row_text = buffer.getvalue()
    return row_text.strip()
202
def csv_string_to_list_strip(data_str):
    """Parse CSV text (any number of rows) into a flat list of whitespace-stripped fields."""
    rows = csv.reader(StringIO(data_str))
    return [field.strip() for row in rows for field in row]
206
class AxisOption:
    """One selectable axis type of the X/Y/Z plot.

    Attributes:
        label: name shown in the axis-type dropdown and written to infotexts ("X Type" etc.).
        type: converter applied to every parsed axis value (int, float, str or str_permutations).
        apply: callable(p, value, all_values) applying one axis value to a processing object.
        format_value: callable(p, option, value) producing the legend text for a value.
        confirm: optional callable(p, values) validating all values before generation starts.
        cost: relative cost of switching values; run() iterates the costliest axis outermost.
        prepare: optional callable turning the raw text into the value list (replaces CSV parsing).
        choices: optional zero-argument callable listing selectable values for the dropdown.
    """

    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
        # Identification and value conversion.
        self.label = label
        self.type = type
        # Hooks used while parsing, applying and labelling values.
        self.apply = apply
        self.format_value = format_value
        self.confirm = confirm
        self.prepare = prepare
        self.choices = choices
        # Scheduling hint for the nested generation loops.
        self.cost = cost
218
class AxisOptionImg2Img(AxisOption):
    """Axis option offered only in the img2img UI (see the filter in Script.ui)."""

    def __init__(self, *args, **kwargs):
        self.is_img2img = True
        super().__init__(*args, **kwargs)
224
class AxisOptionTxt2Img(AxisOption):
    """Axis option offered only in the txt2img UI (see the filter in Script.ui)."""

    def __init__(self, *args, **kwargs):
        self.is_img2img = False
        super().__init__(*args, **kwargs)
230
# All axis types the X/Y/Z plot knows about. Script.ui keeps plain AxisOption entries plus the
# mode-specific subclasses matching the current tab. The axis dropdowns use type="index" into that
# filtered list, and entry 0 ("Nothing") / entry 1 ("Seed") are the UI defaults. So ORDER MATTERS:
# reordering changes which option an index (or a default) refers to.
# `cost` hints how expensive switching values is; run() iterates the costliest axis outermost.
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]),
    # The hires pass lists samplers_for_img2img.
    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    # Switching checkpoints is the costliest change, hence cost=1.0.
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
    AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
    AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
    AxisOption("Schedule rho", float, apply_override("rho")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
    AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
    AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
    AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
    AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
    AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
    AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
    # Boolean overrides: the text "True"/"False" is converted by apply_override(boolean=True).
    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
    AxisOption("SGM noise multiplier", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),
    AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
    AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
    AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
]
275
276
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
    """Generate every (x, y, z) cell and assemble the sub-grids and the combined Z grid.

    Args:
        p: the template processing object; read for n_iter and for empty fallback results.
        xs, ys, zs: axis value lists.
        x_labels, y_labels, z_labels: legend texts for the values of each axis.
        cell: callable(x, y, z, ix, iy, iz) -> Processed that generates one cell.
        draw_legend: whether to annotate the grids with the axis labels.
        include_lone_images, include_sub_grids: accepted but not used here; the caller trims the
            result according to them.
        first_axes_processed, second_axes_processed: 'x' / 'y' / 'z' naming the outermost and
            middle loop axes; the remaining axis is iterated innermost.
        margin_size: margin in pixels passed to the legend drawing.

    Returns:
        A Processed object whose ``images`` and ``infotexts`` are laid out as
        [Z grid, sub-grid 0 .. sub-grid len(zs)-1, cell images ...]. ``all_prompts`` and
        ``all_seeds`` have NO entry for the Z grid (see the TODO below), so they are shifted by one
        relative to ``images``. Returns an empty ``Processed(p, [])`` if nothing could be processed.
    """
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]

    list_size = (len(xs) * len(ys) * len(zs))

    processed_result = None

    state.job_count = list_size * p.n_iter

    def process_cell(x, y, z, ix, iy, iz):
        # Generate one cell and store its first image/metadata at the cell's flat index.
        nonlocal processed_result

        def index(ix, iy, iz):
            # Flat row-major index: x varies fastest, then y, then z.
            return ix + iy * len(xs) + iz * len(xs) * len(ys)

        state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"

        processed: Processed = cell(x, y, z, ix, iy, iz)

        if processed_result is None:
            # Use our first processed result object as a template container to hold our full results
            processed_result = copy(processed)
            processed_result.images = [None] * list_size
            processed_result.all_prompts = [None] * list_size
            processed_result.all_seeds = [None] * list_size
            processed_result.infotexts = [None] * list_size
            processed_result.index_of_first_image = 1

        idx = index(ix, iy, iz)
        if processed.images:
            # Non-empty list indicates some degree of success.
            processed_result.images[idx] = processed.images[0]
            processed_result.all_prompts[idx] = processed.prompt
            processed_result.all_seeds[idx] = processed.seed
            processed_result.infotexts[idx] = processed.infotexts[0]
        else:
            # Failed or skipped cell: put a blank placeholder so the grid layout stays intact.
            cell_mode = "P"
            cell_size = (processed_result.width, processed_result.height)
            if processed_result.images[0] is not None:
                cell_mode = processed_result.images[0].mode
                # This corrects size in case of batches:
                cell_size = processed_result.images[0].size
            processed_result.images[idx] = Image.new(cell_mode, cell_size)

    # Iterate the axes in the order chosen by the caller (costliest axis outermost); the axis not
    # named by first/second_axes_processed is the innermost loop.
    if first_axes_processed == 'x':
        for ix, x in enumerate(xs):
            if second_axes_processed == 'y':
                for iy, y in enumerate(ys):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iz, z in enumerate(zs):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'y':
        for iy, y in enumerate(ys):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iz, z in enumerate(zs):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'z':
        for iz, z in enumerate(zs):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iy, y in enumerate(ys):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)

    if not processed_result:
        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
        print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
        return Processed(p, [])
    elif not any(processed_result.images):
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        return Processed(p, [])

    z_count = len(zs)

    # Build one X*Y sub-grid per z value and insert it at the front of the lists. After i inserts
    # the cell images of slice i are shifted right by i, hence the "+ i" in start_index.
    for i in range(z_count):
        start_index = (i * len(xs) * len(ys)) + i
        end_index = start_index + len(xs) * len(ys)
        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
        if draw_legend:
            grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
        processed_result.images.insert(i, grid)
        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])

    # Combine the sub-grids (now images[:z_count]) into a single row labelled with the Z values.
    sub_grid_size = processed_result.images[0].size
    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
    if draw_legend:
        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
    processed_result.images.insert(0, z_grid)
    # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
    # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
    # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
    processed_result.infotexts.insert(0, processed_result.infotexts[0])

    return processed_result
387
class SharedSettingsStackHelper:
    """Context manager that snapshots global options the X/Y/Z axes can modify and restores them.

    On entry it saves ``CLIP_stop_at_last_layers``, ``sd_vae`` and ``uni_pc_order`` (written by
    apply_clip_skip / apply_uni_pc_order). On exit it writes them back and reloads the model and VAE
    weights (presumably so the restored VAE selection is applied again after apply_vae swapped it).
    Exceptions from the ``with`` body are not suppressed.
    """

    def __enter__(self):
        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
        self.vae = opts.sd_vae
        self.uni_pc_order = opts.uni_pc_order

    def __exit__(self, exc_type, exc_value, tb):
        opts.data["sd_vae"] = self.vae
        opts.data["uni_pc_order"] = self.uni_pc_order
        try:
            modules.sd_models.reload_model_weights()
            modules.sd_vae.reload_vae_weights()
        finally:
            # Restore clip skip even if a reload fails, so a failing reload cannot leave the
            # grid's clip-skip value in the user's settings. On success the order is unchanged.
            opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
402
# Range syntax accepted in numeric axis values (parsed in Script.run / process_axis):
#   "1-5"       -> inclusive range with step 1
#   "1-5 (+2)"  -> inclusive range with an explicit signed step
#   "1-5 [3]"   -> the given number of evenly spaced values between the endpoints (count variants)
# The decimal point is escaped as "\." on purpose: a bare "." matches ANY character. With a bare
# ".", e.g. "5--3" would parse its first endpoint as "5-" and then fail in float().
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")

re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*])?\s*")
409
class Script(scripts.Script):
    """The "X/Y/Z plot" script: generates one image per combination of up to three axis values and
    assembles them into annotated grids."""

    def title(self):
        """Return the script's display name (also written to grid infotexts as "Script")."""
        return "X/Y/Z plot"

    def ui(self, is_img2img):
        """Build the controls for the three axes and the grid options.

        Returns the components in the same order as run()'s parameters after ``p``.
        """
        # Exact type check on purpose: plain AxisOption entries are shared, while the subclasses
        # are only offered in the tab matching their is_img2img flag.
        self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]

        with gr.Row():
            with gr.Column(scale=19):
                with gr.Row():
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)

                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)

                with gr.Row():
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
                    z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
                    z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
                    fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

        with gr.Row(variant="compact", elem_id="axis_options"):
            with gr.Column():
                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
                with gr.Row():
                    vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.")
                    vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.")
                    vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.")
            with gr.Column():
                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
                csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
            with gr.Column():
                margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

        with gr.Row(variant="compact", elem_id="swap_axes"):
            swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
            swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
            swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")

        def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
            # Axis types arrive as indices but the dropdowns are updated by label.
            return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown

        xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
        swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
        yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
        swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
        xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
        swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)

        def fill(axis_type, csv_mode):
            # Fill the textbox (CSV mode) or the dropdown with every available choice of the axis.
            axis = self.current_axis_options[axis_type]
            if axis.choices:
                if csv_mode:
                    return list_to_csv_string(axis.choices()), gr.update()
                else:
                    return gr.update(), axis.choices()
            else:
                return gr.update(), gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
        fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
        fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])

        def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
            # Show the textbox or the dropdown for the selected axis, carrying over values that are
            # valid choices when switching between the two input modes.
            axis_type = axis_type or 0  # if axis type is None set to 0

            choices = self.current_axis_options[axis_type].choices
            has_choices = choices is not None

            if has_choices:
                choices = choices()
                if csv_mode:
                    if axis_values_dropdown:
                        axis_values = list_to_csv_string(list(filter(lambda x: x in choices, axis_values_dropdown)))
                        axis_values_dropdown = []
                else:
                    if axis_values:
                        axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values)))
                        axis_values = ""

            return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values),
                    gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=axis_values_dropdown))

        x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
        y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
        z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])

        def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
            # Toggling CSV mode re-runs select_axis for all three axes.
            _fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
            _fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
            _fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
            return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown

        csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])

        def get_dropdown_update_from_params(axis, params):
            # Turn the "<axis> Values" CSV text from an infotext into a dropdown value list.
            val_key = f"{axis} Values"
            vals = params.get(val_key, "")
            valslist = csv_string_to_list_strip(vals)
            return gr.update(value=valslist)

        # (component, infotext key or params -> update callable) pairs for restoring the axes from
        # generation parameters.
        self.infotext_fields = (
            (x_type, "X Type"),
            (x_values, "X Values"),
            (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
            (y_type, "Y Type"),
            (y_values, "Y Values"),
            (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
            (z_type, "Z Type"),
            (z_values, "Z Values"),
            (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
        )

        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode]

    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode):
        """Parse the three axes, generate every cell and assemble the grid(s).

        Returns the Processed object from draw_xyz_grid, whose images are laid out as
        [main grid, sub-grids..., cell images...]. Sub-grids and cell images are dropped unless
        include_sub_grids / include_lone_images are set.

        Raises:
            AssertionError: if the resulting grid would exceed ``opts.img_max_size_mp`` megapixels.
            RuntimeError: from an axis' confirm/apply hook when a value is invalid.
        """
        x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0  # if axis type is None set to 0

        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals, vals_dropdown):
            # Turn the raw axis input into the final list of typed values (expanding ranges and
            # permutations) and run the axis' confirm hook on it.
            if opt.label == 'Nothing':
                return [0]

            if opt.choices is not None and not csv_mode:
                valslist = vals_dropdown
            elif opt.prepare is not None:
                valslist = opt.prepare(vals)
            else:
                valslist = csv_string_to_list_strip(vals)

            if opt.type == int:
                valslist_ext = []

                for val in valslist:
                    if val.strip() == '':
                        continue
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        start = int(m.group(1))
                        step = int(m.group(3)) if m.group(3) is not None else 1
                        # range() excludes its stop value: move the stop one past the end in the
                        # direction of travel so descending ranges such as "5-1 (-1)" also include
                        # their end (previously "+1" was used unconditionally, which lost 2 and 1).
                        end = int(m.group(2)) + (1 if step > 0 else -1)

                        valslist_ext += list(range(start, end, step))
                    elif mc is not None:
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == float:
                valslist_ext = []

                for val in valslist:
                    if val.strip() == '':
                        continue
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1

                        values = np.arange(start, end + step, step)
                        # Stopping at end + step overshoots whenever (end - start) is not a multiple
                        # of step, e.g. "0-1 (+0.3)" would include 1.2. Drop anything past the end,
                        # with a small tolerance so float rounding does not lose the end value.
                        tolerance = abs(step) * 1e-6
                        if step > 0:
                            values = values[values <= end + tolerance]
                        else:
                            values = values[values >= end - tolerance]
                        valslist_ext += values.tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == str_permutations:
                valslist = list(permutations(valslist))

            valslist = [opt.type(x) for x in valslist]

            # Confirm options are valid before starting
            if opt.confirm:
                opt.confirm(p, valslist)

            return valslist

        # In dropdown mode the values are re-serialized as CSV so they can be written to infotexts.
        x_opt = self.current_axis_options[x_type]
        if x_opt.choices is not None and not csv_mode:
            x_values = list_to_csv_string(x_values_dropdown)
        xs = process_axis(x_opt, x_values, x_values_dropdown)

        y_opt = self.current_axis_options[y_type]
        if y_opt.choices is not None and not csv_mode:
            y_values = list_to_csv_string(y_values_dropdown)
        ys = process_axis(y_opt, y_values, y_values_dropdown)

        z_opt = self.current_axis_options[z_type]
        if z_opt.choices is not None and not csv_mode:
            z_values = list_to_csv_string(z_values_dropdown)
        zs = process_axis(z_opt, z_values, z_values_dropdown)

        # this could be moved to common code, but unlikely to be ever triggered anywhere else
        Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes
        grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
        assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'

        def fix_axis_seeds(axis_opt, axis_list):
            # Replace "random" seeds (-1 / empty) with concrete random seeds so they can be recorded.
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)
            zs = fix_axis_seeds(z_opt, zs)

        # Estimate the total sampling steps for the progress bar.
        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys) * len(zs)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs) * len(zs)
        elif z_opt.label == 'Steps':
            total_steps = sum(zs) * len(xs) * len(ys)
        else:
            total_steps = p.steps * len(xs) * len(ys) * len(zs)

        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            if x_opt.label == "Hires steps":
                total_steps += sum(xs) * len(ys) * len(zs)
            elif y_opt.label == "Hires steps":
                total_steps += sum(ys) * len(xs) * len(zs)
            elif z_opt.label == "Hires steps":
                total_steps += sum(zs) * len(xs) * len(ys)
            elif p.hr_second_pass_steps:
                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
            else:
                total_steps *= 2

        total_steps *= p.n_iter

        image_cell_count = p.n_iter * p.batch_size
        cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
        plural_s = 's' if len(zs) > 1 else ''
        print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
        shared.total_tqdm.updateTotal(total_steps)

        state.xyz_plot_x = AxisInfo(x_opt, xs)
        state.xyz_plot_y = AxisInfo(y_opt, ys)
        state.xyz_plot_z = AxisInfo(z_opt, zs)

        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        first_axes_processed = 'z'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
            if y_opt.cost > z_opt.cost:
                second_axes_processed = 'y'
            else:
                second_axes_processed = 'z'
        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
            first_axes_processed = 'y'
            if x_opt.cost > z_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'z'
        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
            first_axes_processed = 'z'
            if x_opt.cost > y_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'y'

        # Slot 0: main grid infotext; slots 1..len(zs): one per sub-grid.
        grid_infotext = [None] * (1 + len(zs))

        def cell(x, y, z, ix, iy, iz):
            # Generate one cell on a shallow copy of p with the three axis values applied.
            if shared.state.interrupted or state.stopping_generation:
                return Processed(p, [], p.seed, "")

            pc = copy(p)
            # copy() is shallow: give the cell its own styles list so apply_styles does not mutate p.
            pc.styles = pc.styles[:]
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)
            z_opt.apply(pc, z, zs)

            xdim = len(xs) if vary_seeds_x else 1
            ydim = len(ys) if vary_seeds_y else 1

            if vary_seeds_x:
                pc.seed += ix
            if vary_seeds_y:
                pc.seed += iy * xdim
            if vary_seeds_z:
                pc.seed += iz * xdim * ydim

            try:
                res = process_images(pc)
            except Exception as e:
                errors.display(e, "generating image for xyz plot")

                res = Processed(p, [], p.seed, "")

            # Sets subgrid infotexts
            subgrid_index = 1 + iz
            if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)
                pc.extra_generation_params['Script'] = self.title()

                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])

                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])

                grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            # Sets main grid infotext
            if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)

                if z_opt.label != 'Nothing':
                    pc.extra_generation_params["Z Type"] = z_opt.label
                    pc.extra_generation_params["Z Values"] = z_values
                    if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])

                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            return res

        with SharedSettingsStackHelper():
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed,
                margin_size=margin_size
            )

        if not processed.images:
            # It broke, no further handling needed.
            return processed

        z_count = len(zs)

        # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
        processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]

        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count+1]

        if opts.grid_save:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                # TODO: See previous comment about intentional data misalignment.
                adj_g = g-1 if g > 0 else g
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
                if not include_sub_grids:  # if not include_sub_grids then skip saving after the first grid
                    break

        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for _ in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]

        return processed