stable-diffusion-webui
511 строк · 16.9 Кб
1import dataclasses2import inspect3import os4from collections import namedtuple5from typing import Optional, Any6
7from fastapi import FastAPI8from gradio import Blocks9
10from modules import errors, timer11
12
def report_exception(c, job):
    """Report an error raised by a registered callback.

    Intended to be called from inside an ``except`` block; ``exc_info=True`` is forwarded
    to ``errors.report``.

    Args:
        c: the ScriptCallback entry whose callback failed; ``c.script`` names its owner.
        job: label of the hook that was being run.
    """
    message = f"Error executing callback {job} for {c.script}"
    errors.report(message, exc_info=True)
16
class ImageSaveParams:
    """Values describing an image that is about to be, or has been, saved to a file.

    Attributes:
        image: the PIL image itself.
        p: p object with processing parameters; either StableDiffusionProcessing or an
            object with the same fields.
        filename: name of the file that the image would be saved to.
        pnginfo: dictionary with parameters for the image's PNG info data; the infotext
            is stored under the key 'parameters'.
    """

    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        self.p = p
        self.filename = filename
        self.pnginfo = pnginfo
31
class ExtraNoiseParams:
    """Payload for ``on_extra_noise`` callbacks.

    Attributes:
        noise: random noise generated by the seed.
        x: latent representation of the image.
        xi: noisy latent representation of the image.
    """

    def __init__(self, noise, x, xi):
        self.noise, self.x, self.xi = noise, x, xi
43
class CFGDenoiserParams:
    """Payload for ``on_cfg_denoiser`` callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        image_cond: conditioning image.
        sigma: current sigma noise step value.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
        text_cond: encoder hidden states of text conditioning from the prompt.
        text_uncond: encoder hidden states of text conditioning from the negative prompt.
        denoiser: current CFGDenoiser object with processing parameters (None if not given).
    """

    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
        # latent state and conditioning image
        self.x = x
        self.image_cond = image_cond

        # sampling progress
        self.sigma = sigma
        self.sampling_step = sampling_step
        self.total_sampling_steps = total_sampling_steps

        # text conditioning (positive / negative prompt)
        self.text_cond = text_cond
        self.text_uncond = text_uncond

        self.denoiser = denoiser
70
class CFGDenoisedParams:
    """Payload for ``on_cfg_denoised`` callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
        inner_model: inner model reference used for denoising.
    """

    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
        self.x = x
        self.sampling_step, self.total_sampling_steps = sampling_step, total_sampling_steps
        self.inner_model = inner_model
85
class AfterCFGCallbackParams:
    """Payload for ``on_cfg_after_cfg`` callbacks.

    Attributes:
        x: latent image representation in the process of being denoised.
        sampling_step: current sampling step number.
        total_sampling_steps: total number of sampling steps planned.
    """

    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x, self.sampling_step, self.total_sampling_steps = x, sampling_step, total_sampling_steps
97
class UiTrainTabParams:
    """Payload for ``on_ui_train_tabs`` callbacks.

    Attributes:
        txt2img_preview_params: preview parameters supplied by the caller; their exact
            shape is not shown in this module.
    """

    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params
102
class ImageGridLoopParams:
    """Payload for ``on_image_grid`` callbacks; callbacks may modify these fields.

    Attributes:
        imgs: the images to be placed in the grid.
        cols: number of grid columns.
        rows: number of grid rows.
    """

    def __init__(self, imgs, cols, rows):
        self.imgs, self.cols, self.rows = imgs, cols, rows
109
@dataclasses.dataclass
class BeforeTokenCounterParams:
    """Prompt data passed to ``on_before_token_counter`` callbacks, which may edit the fields."""

    prompt: str  # the prompt text to be token-counted
    steps: int  # presumably the sampling step count -- confirm against the caller
    styles: list  # presumably the selected style names -- confirm against the caller
    is_positive: bool = True  # defaults to True; presumably False for the negative prompt
118
# A registered callback: `script` is the filename of the script that registered it,
# `callback` is the registered function.
ScriptCallback = namedtuple("ScriptCallback", ("script", "callback"))

# Registry of all callbacks, keyed by hook name. Every value is its own list of
# ScriptCallback entries kept in registration order.
callback_map = {
    hook_name: []
    for hook_name in (
        'callbacks_app_started',
        'callbacks_model_loaded',
        'callbacks_ui_tabs',
        'callbacks_ui_train_tabs',
        'callbacks_ui_settings',
        'callbacks_before_image_saved',
        'callbacks_image_saved',
        'callbacks_extra_noise',
        'callbacks_cfg_denoiser',
        'callbacks_cfg_denoised',
        'callbacks_cfg_after_cfg',
        'callbacks_before_component',
        'callbacks_after_component',
        'callbacks_image_grid',
        'callbacks_infotext_pasted',
        'callbacks_script_unloaded',
        'callbacks_before_ui',
        'callbacks_on_reload',
        'callbacks_list_optimizers',
        'callbacks_list_unets',
        'callbacks_before_token_counter',
    )
}
143
144
def clear_callbacks():
    """Unregister every callback for every hook.

    Each list in ``callback_map`` is emptied in place rather than replaced, so existing
    references to those lists see the change.
    """
    for hook_name in callback_map:
        callback_map[hook_name].clear()
149
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    """Run every ``on_app_started`` callback with the gradio Blocks and the FastAPI app.

    After each callback succeeds, a startup-timer entry named after the registering
    script's basename is recorded. Any exception (from the callback or the timer) is
    reported and the next callback runs.
    """
    registered = callback_map['callbacks_app_started']
    for entry in registered:
        try:
            entry.callback(demo, app)
            timer.startup_timer.record(os.path.basename(entry.script))
        except Exception:
            report_exception(entry, 'app_started_callback')
158
def app_reload_callback():
    """Run every ``on_before_reload`` callback (no arguments); failures are reported and skipped."""
    registered = callback_map['callbacks_on_reload']
    for entry in registered:
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'callbacks_on_reload')
166
def model_loaded_callback(sd_model):
    """Pass the freshly loaded model to every ``on_model_loaded`` callback.

    A callback that raises is reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_model_loaded']
    for entry in registered:
        try:
            entry.callback(sd_model)
        except Exception:
            report_exception(entry, 'model_loaded_callback')
174
def ui_tabs_callback():
    """Gather the extra UI tabs contributed by ``on_ui_tabs`` callbacks.

    Each callback returns either None (treated as no tabs) or an iterable of tab entries,
    which is appended to the result. A failing callback is reported and contributes nothing.

    Returns:
        list of tab entries in registration order.
    """
    tabs = []
    for entry in callback_map['callbacks_ui_tabs']:
        try:
            tabs.extend(entry.callback() or [])
        except Exception:
            report_exception(entry, 'ui_tabs_callback')
    return tabs
186
def ui_train_tabs_callback(params: UiTrainTabParams):
    """Run every ``on_ui_train_tabs`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_ui_train_tabs']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'callbacks_ui_train_tabs')
194
def ui_settings_callback():
    """Run every ``on_ui_settings`` callback (no arguments); failures are reported and skipped."""
    registered = callback_map['callbacks_ui_settings']
    for entry in registered:
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'ui_settings_callback')
202
def before_image_saved_callback(params: ImageSaveParams):
    """Let every ``on_before_image_saved`` callback inspect or edit *params* before saving.

    Failures are reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_before_image_saved']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'before_image_saved_callback')
210
def image_saved_callback(params: ImageSaveParams):
    """Notify every ``on_image_saved`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_image_saved']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'image_saved_callback')
218
def extra_noise_callback(params: ExtraNoiseParams):
    """Run every ``on_extra_noise`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_extra_noise']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'callbacks_extra_noise')
226
def cfg_denoiser_callback(params: CFGDenoiserParams):
    """Run every ``on_cfg_denoiser`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_cfg_denoiser']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_denoiser_callback')
234
def cfg_denoised_callback(params: CFGDenoisedParams):
    """Run every ``on_cfg_denoised`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_cfg_denoised']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_denoised_callback')
242
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
    """Run every ``on_cfg_after_cfg`` callback with *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_cfg_after_cfg']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'cfg_after_cfg_callback')
250
def before_component_callback(component, **kwargs):
    """Forward *component* and its constructor kwargs to every ``on_before_component`` callback.

    Failures are reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_before_component']
    for entry in registered:
        try:
            entry.callback(component, **kwargs)
        except Exception:
            report_exception(entry, 'before_component_callback')
258
def after_component_callback(component, **kwargs):
    """Forward *component* and its constructor kwargs to every ``on_after_component`` callback.

    Failures are reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_after_component']
    for entry in registered:
        try:
            entry.callback(component, **kwargs)
        except Exception:
            report_exception(entry, 'after_component_callback')
266
def image_grid_callback(params: ImageGridLoopParams):
    """Let every ``on_image_grid`` callback adjust *params* before the grid is built.

    Failures are reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_image_grid']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'image_grid')
274
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
    """Give every ``on_infotext_pasted`` callback the raw *infotext* and its parsed *params*.

    Failures are reported and the remaining callbacks still run.
    """
    registered = callback_map['callbacks_infotext_pasted']
    for entry in registered:
        try:
            entry.callback(infotext, params)
        except Exception:
            report_exception(entry, 'infotext_pasted')
282
def script_unloaded_callback():
    """Run every ``on_script_unloaded`` callback, newest registration first.

    The reverse order is presumably so that later hooks/hijacks are undone before the ones
    registered earlier. Failures are reported and skipped.
    """
    registered = callback_map['callbacks_script_unloaded']
    for entry in reversed(registered):
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'script_unloaded')
290
def before_ui_callback():
    """Run every ``on_before_ui`` callback in reverse registration order; failures are reported and skipped."""
    registered = callback_map['callbacks_before_ui']
    for entry in reversed(registered):
        try:
            entry.callback()
        except Exception:
            report_exception(entry, 'before_ui')
298
def list_optimizers_callback():
    """Collect cross attention optimization options from ``on_list_optimizers`` callbacks.

    Every callback receives the same list and appends to it in place; a failing callback
    is reported and skipped.

    Returns:
        the shared list after all callbacks have run (possibly empty).
    """
    optimizers = []
    for entry in callback_map['callbacks_list_optimizers']:
        try:
            entry.callback(optimizers)
        except Exception:
            report_exception(entry, 'list_optimizers')
    return optimizers
310
def list_unets_callback():
    """Collect alternative unet options from ``on_list_unets`` callbacks.

    Every callback receives the same list and appends to it in place; a failing callback
    is reported and skipped.

    Returns:
        the shared list after all callbacks have run (possibly empty).
    """
    unets = []
    for entry in callback_map['callbacks_list_unets']:
        try:
            entry.callback(unets)
        except Exception:
            report_exception(entry, 'list_unets')
    return unets
322
def before_token_counter_callback(params: BeforeTokenCounterParams):
    """Let every ``on_before_token_counter`` callback edit *params*; failures are reported and skipped."""
    registered = callback_map['callbacks_before_token_counter']
    for entry in registered:
        try:
            entry.callback(params)
        except Exception:
            report_exception(entry, 'before_token_counter')
330
def add_callback(callbacks, fun):
    """Append *fun* to *callbacks*, tagged with the file of the script that registered it.

    The owning script is the nearest frame on the call stack whose file is not this module
    (so the ``on_*`` wrappers defined here are skipped). If there is no such frame, the
    entry is tagged ``'unknown file'``.

    Args:
        callbacks: one of the lists stored in ``callback_map``; appended to in place.
        fun: the callable to register.
    """
    filename = 'unknown file'
    frame = inspect.currentframe()
    try:
        # Walk outward lazily and stop at the first frame from another file.
        # getframeinfo(context=0) yields the same filename as inspect.stack() but skips
        # reading source lines, which inspect.stack() does for every frame on the stack.
        while frame is not None:
            frame_filename = inspect.getframeinfo(frame, context=0).filename
            if frame_filename != __file__:
                filename = frame_filename
                break
            frame = frame.f_back
    finally:
        # Drop the frame reference so no reference cycle outlives this call.
        del frame

    callbacks.append(ScriptCallback(filename, fun))
337
def remove_current_script_callbacks():
    """Unregister every callback registered by the calling script, across all hooks.

    The calling script is the nearest frame on the call stack whose file is not this
    module. If it cannot be determined, nothing is removed.
    """
    filename = 'unknown file'
    frame = inspect.currentframe()
    try:
        # Lazy outward walk; context=0 avoids reading source lines for every frame.
        while frame is not None:
            frame_filename = inspect.getframeinfo(frame, context=0).filename
            if frame_filename != __file__:
                filename = frame_filename
                break
            frame = frame.f_back
    finally:
        # Drop the frame reference so no reference cycle outlives this call.
        del frame

    if filename == 'unknown file':
        return

    for callback_list in callback_map.values():
        # Filter in place (slice assignment): one O(n) pass instead of repeated list.remove,
        # and the list objects held in callback_map stay the same.
        callback_list[:] = [cb for cb in callback_list if cb.script != filename]
347
def remove_callbacks_for_function(callback_func):
    """Unregister *callback_func* from every hook it was registered for.

    Entries are matched by ``==`` on the stored callback.

    Args:
        callback_func: the callable to remove.
    """
    for callback_list in callback_map.values():
        # Filter in place (slice assignment): one O(n) pass instead of repeated list.remove,
        # and the list objects held in callback_map stay the same.
        callback_list[:] = [cb for cb in callback_list if not (cb.callback == callback_func)]
353
def on_app_started(callback):
    """Register *callback* to run once the webui has started.

    It is called with the gradio ``Blocks`` object and the fastapi ``FastAPI`` app.
    """
    hook = callback_map['callbacks_app_started']
    add_callback(hook, callback)
359
def on_before_reload(callback):
    """Register *callback* to run just before the server reloads; it takes no arguments."""
    hook = callback_map['callbacks_on_reload']
    add_callback(hook, callback)
364
def on_model_loaded(callback):
    """Register *callback* to run when the stable diffusion model is created.

    The model is passed as the only argument. The callback is also invoked when the
    script is reloaded.
    """
    hook = callback_map['callbacks_model_loaded']
    add_callback(hook, callback)
370
def on_ui_tabs(callback):
    """Register *callback* to contribute new tabs while the UI is being built.

    The callback returns None (no tabs to add) or a list of tuples of the form
    ``(gradio_component, title, elem_id)``:

    - gradio_component: the component holding the tab's contents (usually ``gr.Blocks``)
    - title: the tab text shown to the user
    - elem_id: the HTML id for the tab
    """
    hook = callback_map['callbacks_ui_tabs']
    add_callback(hook, callback)
383
def on_ui_train_tabs(callback):
    """Register *callback* to add tabs inside the train tab while the UI is being built.

    Create the new tabs with ``gr.Tab``.
    """
    hook = callback_map['callbacks_ui_train_tabs']
    add_callback(hook, callback)
390
def on_ui_settings(callback):
    """Register *callback* to run before the UI settings are populated.

    Add settings from it with ``shared.opts.add_option(shared.OptionInfo(...))``.
    """
    hook = callback_map['callbacks_ui_settings']
    add_callback(hook, callback)
396
def on_before_image_saved(callback):
    """Register *callback* to run before an image is saved to a file.

    It receives one argument:
        - params: ImageSaveParams - the parameters the image will be saved with; the
          callback may change fields of this object.
    """
    hook = callback_map['callbacks_before_image_saved']
    add_callback(hook, callback)
404
def on_image_saved(callback):
    """Register *callback* to run after an image has been saved to a file.

    It receives one argument:
        - params: ImageSaveParams - the parameters the image was saved with; changing
          fields of this object has no effect.
    """
    hook = callback_map['callbacks_image_saved']
    add_callback(hook, callback)
412
def on_extra_noise(callback):
    """Register *callback* to run before extra noise is added in img2img or hires fix.

    It receives one argument:
        - params: ExtraNoiseParams - the seed-determined noise and the latent
          representation of the image.
    """
    hook = callback_map['callbacks_extra_noise']
    add_callback(hook, callback)
420
def on_cfg_denoiser(callback):
    """Register *callback* to run in the k-diffusion cfg_denoiser method after the inner model inputs are built.

    It receives one argument:
        - params: CFGDenoiserParams - the inputs to be passed to the inner model, plus
          sampling state details.
    """
    hook = callback_map['callbacks_cfg_denoiser']
    add_callback(hook, callback)
428
def on_cfg_denoised(callback):
    """register a function to be called in the k-diffusion cfg_denoiser method after the inner model has
    produced its denoised output (i.e. later than the ``on_cfg_denoiser`` hook, which fires once the inner
    model inputs are built).
    The callback is called with one argument:
        - params: CFGDenoisedParams - the latent being denoised (``x``), sampling progress
          (``sampling_step`` / ``total_sampling_steps``) and the inner model used for denoising.
    """
    # NOTE(review): placement ("after the inner model ran") is inferred from the hook name and the
    # CFGDenoisedParams payload; confirm against the sampler code that calls cfg_denoised_callback.
    add_callback(callback_map['callbacks_cfg_denoised'], callback)
436
def on_cfg_after_cfg(callback):
    """Register *callback* to run in the k-diffusion cfg_denoiser method once the cfg calculations are done.

    It receives one argument:
        - params: AfterCFGCallbackParams - values handed to the script for
          post-processing after the cfg calculation.
    """
    hook = callback_map['callbacks_cfg_after_cfg']
    add_callback(hook, callback)
444
def on_before_component(callback):
    """Register *callback* to run before a gradio component is created.

    It is called with:
        - component - the gradio component that is about to be created.
        - **kwargs - the arguments to gradio.components.IOComponent.__init__.

    The ``elem_id``/``label`` entries of kwargs identify which component it is. This is
    useful for injecting custom components in the middle of the stock UI.
    """
    hook = callback_map['callbacks_before_component']
    add_callback(hook, callback)
456
def on_after_component(callback):
    """Register *callback* to run after a gradio component is created.

    It is called with the component and the same ``**kwargs`` as ``on_before_component`` callbacks.
    """
    hook = callback_map['callbacks_after_component']
    add_callback(hook, callback)
461
def on_image_grid(callback):
    """Register *callback* to run before an image grid is made.

    It receives one argument:
        - params: ImageGridLoopParams - the values used to build the grid; they may be modified.
    """
    hook = callback_map['callbacks_image_grid']
    add_callback(hook, callback)
469
def on_infotext_pasted(callback):
    """Register *callback* to run before a pasted infotext is applied.

    It receives two arguments:
        - infotext: str - the raw infotext.
        - params: dict[str, Any] - the parsed infotext parameters.
    """
    hook = callback_map['callbacks_infotext_pasted']
    add_callback(hook, callback)
478
def on_script_unloaded(callback):
    """Register *callback* to run before the script is unloaded.

    Use it to revert any hooks, hijacks or monkey-patching the script installed.
    """
    hook = callback_map['callbacks_script_unloaded']
    add_callback(hook, callback)
485
def on_before_ui(callback):
    """Register *callback* to run before the UI is created; it takes no arguments."""
    hook = callback_map['callbacks_before_ui']
    add_callback(hook, callback)
491
def on_list_optimizers(callback):
    """Register *callback* to run while the UI builds the list of cross attention optimization options.

    It receives one argument, a list, and should append objects of type
    modules.sd_hijack_optimizations.SdOptimization to it.
    """
    hook = callback_map['callbacks_list_optimizers']
    add_callback(hook, callback)
499
def on_list_unets(callback):
    """Register *callback* to run while the UI builds the list of alternative unet options.

    It receives one argument, a list, and should append objects of type
    modules.sd_unet.SdUnetOption to it.
    """
    hook = callback_map['callbacks_list_unets']
    add_callback(hook, callback)
506
def on_before_token_counter(callback):
    """Register *callback* to run while the UI counts tokens for a prompt.

    It receives one BeforeTokenCounterParams argument and should modify its fields if needed.
    """
    hook = callback_map['callbacks_before_token_counter']
    add_callback(hook, callback)