stable-diffusion-webui

Форк
0
/
script_callbacks.py 
511 строк · 16.9 Кб
1
import dataclasses
2
import inspect
3
import os
4
from collections import namedtuple
5
from typing import Optional, Any
6

7
from fastapi import FastAPI
8
from gradio import Blocks
9

10
from modules import errors, timer
11

12

13
def report_exception(c, job):
    """Report a failure that happened while running a script callback.

    Args:
        c: the callback entry that failed; its ``script`` attribute is
            included in the message.
        job: label of the dispatch step being executed, included in the message.

    The message is forwarded to ``errors.report`` with ``exc_info=True``
    (presumably so the exception being handled is logged with its traceback;
    confirm against ``modules.errors``).
    """
    errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
15

16

17
class ImageSaveParams:
    """Parameter bundle handed to the before-image-saved / image-saved callbacks."""

    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
30

31

32
class ExtraNoiseParams:
    """Parameter bundle handed to the extra-noise callbacks."""

    def __init__(self, noise, x, xi):
        self.noise = noise
        """Random noise generated by the seed"""

        self.x = x
        """Latent representation of the image"""

        self.xi = xi
        """Noisy latent representation of the image"""
42

43

44
class CFGDenoiserParams:
    """Parameter bundle handed to the cfg_denoiser callbacks."""

    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""

        self.text_cond = text_cond
        """ Encoder hidden states of text conditioning from prompt"""

        self.text_uncond = text_uncond
        """ Encoder hidden states of text conditioning from negative prompt"""

        self.denoiser = denoiser
        """Current CFGDenoiser object with processing parameters"""
69

70

71
class CFGDenoisedParams:
    """Parameter bundle handed to the cfg_denoised callbacks."""

    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""

        self.inner_model = inner_model
        """Inner model reference used for denoising"""
84

85

86
class AfterCFGCallbackParams:
    """Parameter bundle handed to the after-cfg callbacks."""

    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""
96

97

98
class UiTrainTabParams:
    """Parameter bundle handed to the ui_train_tabs callbacks."""

    def __init__(self, txt2img_preview_params):
        # Stored as given; the expected contents are decided by the caller.
        self.txt2img_preview_params = txt2img_preview_params
101

102

103
class ImageGridLoopParams:
    """Parameter bundle handed to the image_grid callbacks.

    Callbacks receive this object itself, so they can change ``imgs``,
    ``cols`` and ``rows`` in place.
    """

    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows
108

109

110
@dataclasses.dataclass
class BeforeTokenCounterParams:
    """Parameter bundle handed to the before_token_counter callbacks."""

    # prompt text whose tokens are about to be counted
    prompt: str
    steps: int
    styles: list

    # defaults to True when not given
    is_positive: bool = True
117

118

119
# One registered callback: the file that registered it and the callable itself.
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])

# Registry of callback lists, keyed by event name. Each value is a list of
# ScriptCallback entries in registration order.
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_extra_noise=[],
    callbacks_cfg_denoiser=[],
    callbacks_cfg_denoised=[],
    callbacks_cfg_after_cfg=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
    callbacks_infotext_pasted=[],
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
    callbacks_on_reload=[],
    callbacks_list_optimizers=[],
    callbacks_list_unets=[],
    callbacks_before_token_counter=[],
)
143

144

145
def clear_callbacks():
    """Empty every callback list in ``callback_map``.

    The lists are cleared in place, so the registry keeps the same keys and
    list objects.
    """
    for callback_list in callback_map.values():
        callback_list.clear()
148

149

150
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    """Run every 'callbacks_app_started' callback with ``(demo, app)``.

    After each callback returns, a startup-timer entry named after the base
    filename of the registering script is recorded. An exception from a
    callback (or from the timer call) is reported and the loop continues.
    """
    for c in callback_map['callbacks_app_started']:
        try:
            c.callback(demo, app)
            timer.startup_timer.record(os.path.basename(c.script))
        except Exception:
            report_exception(c, 'app_started_callback')
157

158

159
def app_reload_callback():
    """Run every 'callbacks_on_reload' callback with no arguments.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_on_reload']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'callbacks_on_reload')
165

166

167
def model_loaded_callback(sd_model):
    """Run every 'callbacks_model_loaded' callback with ``sd_model``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_model_loaded']:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')
173

174

175
def ui_tabs_callback():
    """Collect tab definitions from every 'callbacks_ui_tabs' callback.

    Returns:
        A single list with the items returned by all callbacks, in
        registration order. A falsy return value (e.g. ``None``) adds nothing.

    An exception from one callback is reported and does not stop the rest.
    """
    res = []

    for c in callback_map['callbacks_ui_tabs']:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res
185

186

187
def ui_train_tabs_callback(params: UiTrainTabParams):
    """Run every 'callbacks_ui_train_tabs' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')
193

194

195
def ui_settings_callback():
    """Run every 'callbacks_ui_settings' callback with no arguments.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_ui_settings']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')
201

202

203
def before_image_saved_callback(params: ImageSaveParams):
    """Run every 'callbacks_before_image_saved' callback with ``params``.

    All callbacks receive the same ``params`` object. An exception from one
    callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_before_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')
209

210

211
def image_saved_callback(params: ImageSaveParams):
    """Run every 'callbacks_image_saved' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')
217

218

219
def extra_noise_callback(params: ExtraNoiseParams):
    """Run every 'callbacks_extra_noise' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_extra_noise']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_extra_noise')
225

226

227
def cfg_denoiser_callback(params: CFGDenoiserParams):
    """Run every 'callbacks_cfg_denoiser' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_cfg_denoiser']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoiser_callback')
233

234

235
def cfg_denoised_callback(params: CFGDenoisedParams):
    """Run every 'callbacks_cfg_denoised' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_cfg_denoised']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoised_callback')
241

242

243
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
    """Run every 'callbacks_cfg_after_cfg' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_cfg_after_cfg']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_after_cfg_callback')
249

250

251
def before_component_callback(component, **kwargs):
    """Run every 'callbacks_before_component' callback.

    Each callback is called as ``callback(component, **kwargs)``. An exception
    from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')
257

258

259
def after_component_callback(component, **kwargs):
    """Run every 'callbacks_after_component' callback.

    Each callback is called as ``callback(component, **kwargs)``. An exception
    from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')
265

266

267
def image_grid_callback(params: ImageGridLoopParams):
    """Run every 'callbacks_image_grid' callback with ``params``.

    An exception from one callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid')
273

274

275
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
    """Run every 'callbacks_infotext_pasted' callback with ``(infotext, params)``.

    All callbacks receive the same ``params`` dict. An exception from one
    callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_infotext_pasted']:
        try:
            c.callback(infotext, params)
        except Exception:
            report_exception(c, 'infotext_pasted')
281

282

283
def script_unloaded_callback():
    """Run every 'callbacks_script_unloaded' callback with no arguments.

    Callbacks run in reverse registration order. An exception from one
    callback is reported and does not stop the rest.
    """
    for c in reversed(callback_map['callbacks_script_unloaded']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'script_unloaded')
289

290

291
def before_ui_callback():
    """Run every 'callbacks_before_ui' callback with no arguments.

    Callbacks run in reverse registration order. An exception from one
    callback is reported and does not stop the rest.
    """
    for c in reversed(callback_map['callbacks_before_ui']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'before_ui')
297

298

299
def list_optimizers_callback():
    """Build a list by letting every 'callbacks_list_optimizers' callback fill it.

    Each callback receives the same list and is expected to append to it.

    Returns:
        The list after all callbacks have run.

    An exception from one callback is reported and does not stop the rest.
    """
    res = []

    for c in callback_map['callbacks_list_optimizers']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_optimizers')

    return res
309

310

311
def list_unets_callback():
    """Build a list by letting every 'callbacks_list_unets' callback fill it.

    Each callback receives the same list and is expected to append to it.

    Returns:
        The list after all callbacks have run.

    An exception from one callback is reported and does not stop the rest.
    """
    res = []

    for c in callback_map['callbacks_list_unets']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_unets')

    return res
321

322

323
def before_token_counter_callback(params: BeforeTokenCounterParams):
    """Run every 'callbacks_before_token_counter' callback with ``params``.

    All callbacks receive the same ``params`` object. An exception from one
    callback is reported and does not stop the rest.
    """
    for c in callback_map['callbacks_before_token_counter']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_token_counter')
329

330

331
def add_callback(callbacks, fun):
    """Append ``fun`` to ``callbacks``, tagged with the file that registered it.

    Args:
        callbacks: the list to append to (one of the ``callback_map`` values).
        fun: the callable to register.

    The registering file is the first call-stack frame whose filename is not
    this module. If there is no such frame, 'unknown file' is used.
    """
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if stack else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))
336

337

338
def remove_current_script_callbacks():
    """Remove every callback registered by the calling script.

    The calling script is the first call-stack frame whose filename is not
    this module. If it cannot be determined, nothing is removed.
    """
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if stack else 'unknown file'
    if filename == 'unknown file':
        return
    for callback_list in callback_map.values():
        # Filter in place: the list object is kept, and one pass replaces
        # the repeated list.remove() calls.
        callback_list[:] = [cb for cb in callback_list if cb.script != filename]
346

347

348
def remove_callbacks_for_function(callback_func):
    """Remove every registration of ``callback_func`` from all callback lists.

    Args:
        callback_func: the callable to remove; entries are matched by
            equality with their ``callback`` field.
    """
    for callback_list in callback_map.values():
        # Filter in place: the list object is kept, and one pass replaces
        # the repeated list.remove() calls.
        callback_list[:] = [cb for cb in callback_list if cb.callback != callback_func]
352

353

354
def on_app_started(callback):
    """Register a function to be called when the webui has started.

    The gradio `Blocks` component and the fastapi `FastAPI` object are passed
    as the arguments.
    """
    add_callback(callback_map['callbacks_app_started'], callback)
358

359

360
def on_before_reload(callback):
    """Register a function to be called just before the server reloads.

    The callback is called with no arguments.
    """
    add_callback(callback_map['callbacks_on_reload'], callback)
363

364

365
def on_model_loaded(callback):
    """Register a function to be called when the stable diffusion model is created.

    The model is passed as an argument. This function is also called when the
    script is reloaded.
    """
    add_callback(callback_map['callbacks_model_loaded'], callback)
369

370

371
def on_ui_tabs(callback):
    """Register a function to be called when the UI is creating new tabs.

    The function must return either None, meaning no new tabs are added, or a
    list where each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
    title is tab text displayed to user in the UI
    elem_id is HTML id for the tab
    """
    add_callback(callback_map['callbacks_ui_tabs'], callback)
382

383

384
def on_ui_train_tabs(callback):
    """Register a function to be called when the UI is creating new tabs for the train tab.

    Create your new tabs with gr.Tab. The callback is called with one
    argument, a UiTrainTabParams object.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
389

390

391
def on_ui_settings(callback):
    """Register a function to be called before UI settings are populated.

    Add your settings by using shared.opts.add_option(shared.OptionInfo(...)).
    The callback is called with no arguments.
    """
    add_callback(callback_map['callbacks_ui_settings'], callback)
395

396

397
def on_before_image_saved(callback):
    """Register a function to be called before an image is saved to a file.

    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callback_map['callbacks_before_image_saved'], callback)
403

404

405
def on_image_saved(callback):
    """Register a function to be called after an image is saved to a file.

    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callback_map['callbacks_image_saved'], callback)
411

412

413
def on_extra_noise(callback):
    """Register a function to be called before adding extra noise in img2img or hires fix.

    The callback is called with one argument:
        - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
    """
    add_callback(callback_map['callbacks_extra_noise'], callback)
419

420

421
def on_cfg_denoiser(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method after building the inner model inputs.

    The callback is called with one argument:
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)
427

428

429
def on_cfg_denoised(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method after building the inner model inputs.

    The callback is called with one argument:
        - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.

    NOTE(review): this wording mirrors on_cfg_denoiser; confirm the exact
    point in the sampler where these callbacks fire.
    """
    add_callback(callback_map['callbacks_cfg_denoised'], callback)
435

436

437
def on_cfg_after_cfg(callback):
    """Register a function to be called in the k-diffusion cfg_denoiser method after cfg calculations are completed.

    The callback is called with one argument:
        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
    """
    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
443

444

445
def on_before_component(callback):
    """Register a function to be called before a component is created.

    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)
455

456

457
def on_after_component(callback):
    """Register a function to be called after a component is created.

    The callback is called with the component followed by keyword arguments,
    the same way as callbacks registered with on_before_component.
    """
    add_callback(callback_map['callbacks_after_component'], callback)
460

461

462
def on_image_grid(callback):
    """Register a function to be called before making an image grid.

    The callback is called with one argument:
       - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)
468

469

470
def on_infotext_pasted(callback):
    """Register a function to be called before applying an infotext.

    The callback is called with two arguments:
       - infotext: str - raw infotext.
       - result: dict[str, Any] - parsed infotext parameters.
    """
    add_callback(callback_map['callbacks_infotext_pasted'], callback)
477

478

479
def on_script_unloaded(callback):
    """Register a function to be called before the script is unloaded.

    Any hooks/hijacks/monkeying about that the script did should be reverted
    here. The callback is called with no arguments.
    """
    add_callback(callback_map['callbacks_script_unloaded'], callback)
484

485

486
def on_before_ui(callback):
    """Register a function to be called before the UI is created.

    The callback is called with no arguments.
    """
    add_callback(callback_map['callbacks_before_ui'], callback)
490

491

492
def on_list_optimizers(callback):
    """Register a function to be called when UI is making a list of cross attention optimization options.

    The function will be called with one argument, a list, and shall add
    objects of type modules.sd_hijack_optimizations.SdOptimization to it.
    """
    add_callback(callback_map['callbacks_list_optimizers'], callback)
498

499

500
def on_list_unets(callback):
    """Register a function to be called when UI is making a list of alternative options for unet.

    The function will be called with one argument, a list, and shall add
    objects of type modules.sd_unet.SdUnetOption to it.
    """
    add_callback(callback_map['callbacks_list_unets'], callback)
505

506

507
def on_before_token_counter(callback):
    """Register a function to be called when UI is counting tokens for a prompt.

    The function will be called with one argument of type
    BeforeTokenCounterParams, and should modify its fields if necessary.
    """
    add_callback(callback_map['callbacks_before_token_counter'], callback)
512

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.