stable-diffusion-webui

Форк
0
984 строки · 38.0 Кб
1
import os
2
import re
3
import sys
4
import inspect
5
from collections import namedtuple
6
from dataclasses import dataclass
7

8
import gradio as gr
9

10
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
11

12
# Sentinel returned by Script.show() for scripts whose UI is shown at all times
# (as opposed to scripts picked from the "Script" dropdown).
AlwaysVisible = object()
13

14
class MaskBlendArgs:
    """Argument object handed to Script.on_mask_blend() while inpainting.

    When no denoiser is supplied the blend is the final one (is_final_blend is True);
    otherwise denoiser and sigma describe the current denoising step.
    """

    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
        self.current_latent, self.nmask, self.init_latent = current_latent, nmask, init_latent
        self.mask, self.blended_latent = mask, blended_latent

        self.denoiser = denoiser
        self.is_final_blend = denoiser is None  # no denoiser -> the blend done once at the end
        self.sigma = sigma
25

26
class PostSampleArgs:
    """Argument object handed to Script.post_sample(); wraps the freshly generated samples."""

    def __init__(self, samples):
        # possibly already decoded; see the getattr(samples, 'already_decoded', False) note in Script.post_sample()
        self.samples = samples
29

30
class PostprocessImageArgs:
    """Argument object handed to Script.postprocess_image() and postprocess_image_after_composite()."""

    def __init__(self, image):
        # the single generated image being post-processed
        self.image = image
33

34
class PostProcessMaskOverlayArgs:
    """Argument object handed to Script.postprocess_maskoverlay()."""

    def __init__(self, index, mask_for_overlay, overlay_image):
        self.index, self.mask_for_overlay, self.overlay_image = index, mask_for_overlay, overlay_image
39

40
class PostprocessBatchListArgs:
    """Argument object handed to Script.postprocess_batch_list(); scripts may edit its image list."""

    def __init__(self, images):
        # batch images as a list; replacing/adding/removing entries changes the batch
        self.images = images
43

44

45
@dataclass
class OnComponent:
    """Single argument passed to callbacks registered via Script.on_before_component() / on_after_component()."""

    component: gr.blocks.Block  # the gradio block the callback was registered for
48

49

50
class Script:
    """Base class for user scripts shown in the txt2img and img2img tabs.

    Subclasses implement title() and ui(). What show() returns decides how the script is used:
    a script picked from the "Script" dropdown does its work in run(), while an AlwaysVisible
    script is called through the processing hooks (setup(), before_process(), process(), ...,
    postprocess()). In both cases the values of the components returned by ui() are passed in
    as *args.
    """

    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    # path of the file the script class was loaded from
    filename = None
    # [args_from, args_to) is the slice of the runner's UI inputs / p.script_args holding this script's values
    args_from = None
    args_to = None
    # True for scripts whose show() returned AlwaysVisible
    alwayson = False

    # which tab this instance was created for; tabname is "txt2img" or "img2img"
    is_txt2img = False
    is_img2img = False
    tabname = None

    group = None
    """A gr.Group component that has all script's UI inside it."""

    create_group = True
    """If False, a group component will not be created around the script's UI (intended for alwayson scripts)."""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    on_before_component_elem_id = None
    """list of callbacks to be called before a component with an elem_id is created"""

    on_after_component_elem_id = None
    """list of callbacks to be called after a component with an elem_id is created"""

    setup_for_ui_only = False
    """If true, the script setup will only be run in Gradio UI, not in API"""

    controls = None
    """The list of controls returned by ui() (None if ui() returned None)."""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
         """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def setup(self, p, *args):
        """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
        args contains all values returned by components from ui().
        """
        pass

    def before_process(self, p, *args):
        """
        This function is called very early, before processing begins, for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied.
        Won't be called if p.disable_extra_networks is set.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """
        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
        """
        Called in inpainting mode when the original content is blended with the inpainted content.
        This is called at every step in the denoising process and once at the end.
        If is_final_blend is true, this is called for the final blending stage.
        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
        """

        pass

    def post_sample(self, p, ps: PostSampleArgs, *args):
        """
        Called after the samples have been generated,
        but before they have been decoded by the VAE, if applicable.
        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
        """
        Called for every image after it has been generated, with ppmo carrying that image's
        index, mask_for_overlay and overlay_image (see PostProcessMaskOverlayArgs).
        """

        pass

    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        Same as postprocess_image but after inpaint_full_res composite
        So that it operates on the full image instead of the inpaint_full_res crop region.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as before_component(), otherwise.
        """

        pass

    def on_before_component(self, callback, *, elem_id):
        """
        Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.

        May be called in show() or ui() - but it may be too late in the latter as some components may already be created.

        This function is an alternative to before_component in that it also allows to run before a component is created, but
        it doesn't require to be called for every created component - just for the one you need.
        """
        # the list is created lazily so that each instance gets its own (the class attribute stays None)
        if self.on_before_component_elem_id is None:
            self.on_before_component_elem_id = []

        self.on_before_component_elem_id.append((elem_id, callback))

    def on_after_component(self, callback, *, elem_id):
        """
        Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
        """
        if self.on_after_component_elem_id is None:
            self.on_after_component_elem_id = []

        self.on_after_component_elem_id.append((elem_id, callback))

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id

        The tab name is included only when show() gives the same answer for both tabs, i.e. when the
        same element id could otherwise appear in both txt2img and img2img.
        """

        need_tabname = self.show(True) == self.show(False)
        tabkind = 'img2img' if self.is_img2img else 'txt2img'
        tabname = f"{tabkind}_" if need_tabname else ""
        # lowercase title, whitespace -> '_', then drop anything not [a-z_0-9]
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before hires fix start.
        """
        pass
341

342

343
class ScriptBuiltinUI(Script):
    """Base for built-in scripts: set up only in the Gradio UI, with element ids free of the script title."""

    setup_for_ui_only = True

    def elem_id(self, item_id):
        """Return an HTML element id built from the tab (when needed) and the caller-supplied item_id."""

        if not self.show(True) == self.show(False):
            # shown differently per tab, so the id cannot clash between tabs
            return f'{item_id}'

        tab_prefix = 'img2img' if self.is_img2img else 'txt2img'
        return f'{tab_prefix}_{item_id}'
353

354

355
# Base directory of the script currently being imported. load_scripts() sets it to each
# script's basedir while importing that script and resets it to paths.script_path afterwards;
# scripts read it through basedir().
current_basedir = paths.script_path
356

357

358
def basedir():
    """Return the base directory of the script that is currently being loaded.

    For scripts in the main scripts directory this is the main directory (where webui.py
    resides); for scripts in an extension (e.g. extensions/aesthetic/scripts/aesthetic.py)
    it is the extension's directory (extensions/aesthetic).

    load_scripts() only points the value at a script's directory while that script module
    is being imported and resets it to paths.script_path afterwards, so call this at import
    time (module level) of a script to get a meaningful answer.
    """
    return current_basedir
364

365

366
# One discovered script file: the directory it belongs to, its file name, and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# Filled by load_scripts(): Script subclasses, and ScriptPostprocessing subclasses, found in loaded modules.
scripts_data = []
postprocessing_scripts_data = []

# Registration record for one discovered script class and the module/file it came from.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
371

372
def topological_sort(dependencies):
    """Order names so that each one comes after the dependencies it lists.

    Accepts a dictionary mapping a name to an iterable of names it depends on and returns a
    list of the dictionary's keys in dependency order. Ignores errors relating to missing
    dependencies (names that are not keys are skipped) or circular dependencies (a cycle is
    broken at the first name already visited). Ties follow the dictionary's iteration order.

    The depth-first traversal keeps an explicit stack instead of recursing, so long
    dependency chains cannot hit Python's recursion limit.
    """

    visited = set()
    result = []

    for start in dependencies:
        if start in visited:
            continue

        visited.add(start)
        # each frame: (name, iterator over the dependencies of name not yet examined)
        stack = [(start, iter(dependencies.get(start, [])))]
        while stack:
            name, pending = stack[-1]
            for dep in pending:
                if dep in dependencies and dep not in visited:
                    visited.add(dep)
                    stack.append((dep, iter(dependencies.get(dep, []))))
                    break
            else:
                # every dependency of name has been placed; emit it (post-order)
                stack.pop()
                result.append(name)

    return result
394

395

396
@dataclass
class ScriptWithDependencies:
    """A discovered script file together with the ordering constraints declared for it.

    Built by list_scripts(). For extension scripts the three lists come from the extension's
    metadata ("Requires", "Before", "After"); scripts in the main scripts directory get empty lists.
    """

    script_canonical_name: str  # bare filename, or "[builtin/]<extension>/<filename>" for extension scripts
    file: ScriptFile
    requires: list  # names that must be present; a missing one is only reported
    load_before: list  # names this script should be loaded before
    load_after: list  # names this script should be loaded after
403

404

405
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    """Discover script files and return them ordered by their declared dependencies.

    Scripts come from ``<paths.script_path>/<scriptdirname>`` (keyed by bare filename; the file
    suffix is lowercased before being compared with ``extension``) and, when include_extensions
    is True, from the same subdirectory of every active extension (keyed as
    ``[builtin/]<extension canonical name>/<filename>``). Extension metadata may declare
    "Requires", "Before" and "After" entries naming a script key or an extension; these are
    turned into load-after edges and ordered with topological_sort().

    Args:
        scriptdirname: subdirectory to search, e.g. "scripts".
        extension: file suffix to match, e.g. ".py".
        include_extensions: also search the active extensions.

    Returns:
        list of ScriptFile in load order. Unmet "Requires" entries are reported through
        errors.report but the script is still returned.
    """
    scripts = {}

    # query the active extensions once so every lookup below sees the same snapshot
    active_extensions = list(extensions.active())
    loaded_extensions = {ext.canonical_name: ext for ext in active_extensions}
    loaded_extensions_scripts = {ext.canonical_name: [] for ext in active_extensions}

    # build script dependency map
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                continue

            if os.path.splitext(filename)[1].lower() != extension:
                continue

            script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
            scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])

    if include_extensions:
        for ext in active_extensions:
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                if not os.path.isfile(extension_script.path):
                    continue

                script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
                relative_path = scriptdirname + "/" + extension_script.filename

                script = ScriptWithDependencies(
                    script_canonical_name=script_canonical_name,
                    file=extension_script,
                    requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
                    load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
                    load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
                )

                scripts[script_canonical_name] = script
                loaded_extensions_scripts[ext.canonical_name].append(script)

    for script_canonical_name, script in scripts.items():
        # "Before" is an inverse dependency: append this script's name to the load_after
        # list of the named script, or of every script of the named extension
        for load_before in script.load_before:
            other_script = scripts.get(load_before)
            if other_script:
                other_script.load_after.append(script_canonical_name)

            other_extension_scripts = loaded_extensions_scripts.get(load_before)
            if other_extension_scripts:
                for other_script in other_extension_scripts:
                    other_script.load_after.append(script_canonical_name)

        # if "After" names an extension (and not a script), replace it with all of that extension's scripts
        for load_after in list(script.load_after):
            if load_after not in scripts and load_after in loaded_extensions_scripts:
                script.load_after.remove(load_after)

                for other_script in loaded_extensions_scripts.get(load_after, []):
                    script.load_after.append(other_script.script_canonical_name)

    dependencies = {}

    for script_canonical_name, script in scripts.items():
        for required_script in script.requires:
            if required_script not in scripts and required_script not in loaded_extensions:
                errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)

        dependencies[script_canonical_name] = script.load_after

    ordered_scripts = topological_sort(dependencies)
    scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]

    return scripts_list
481

482

483
def list_files_with_name(filename):
    """Return the paths of ``filename`` found directly inside the main script directory and each
    active extension's directory, in that order. Missing directories or files are skipped."""
    search_dirs = [paths.script_path, *(ext.path for ext in extensions.active())]

    candidates = (os.path.join(folder, filename) for folder in search_dirs if os.path.isdir(folder))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]
497

498

499
def load_scripts():
    """(Re)discover and import all script files, then rebuild the global script runners.

    Clears scripts_data, postprocessing_scripts_data and all script callbacks, then imports
    every file returned by list_scripts() for "scripts" (main directory and extensions) and for
    "modules/processing_scripts" (main directory only). Every Script / ScriptPostprocessing
    subclass found in a loaded module's namespace is recorded. Finally creates fresh
    scripts_txt2img, scripts_img2img and scripts_postproc runners.

    A script that fails to import is reported through errors.report and skipped.
    """
    global current_basedir, scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)

    # NOTE: this keeps a reference to the current list object (not a copy); assigning it back
    # below undoes the prepended extension directory. Kept as-is to preserve existing behavior.
    syspath = sys.path

    def register_scripts_from_module(module, scriptfile):
        """Record every Script / ScriptPostprocessing subclass in module's namespace as coming from scriptfile."""
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    # the "scripts" part of scripts_list is already in dependency order; the processing
    # scripts are appended after it and are not part of that ordering
    for scriptfile in scripts_list:
        try:
            if scriptfile.basedir != paths.script_path:
                # make modules in the script's own base directory importable while it loads
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module, scriptfile)

        except Exception:
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
543

544

545
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call func(*args, **kwargs) and return its result.

    If the call raises, the error is reported (labelled "<filename>/<funcname>") and
    ``default`` is returned instead; the exception is not propagated.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception:
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
        return default
    return outcome
552

553

554
class ScriptRunner:
555
    def __init__(self):
556
        self.scripts = []
557
        self.selectable_scripts = []
558
        self.alwayson_scripts = []
559
        self.titles = []
560
        self.title_map = {}
561
        self.infotext_fields = []
562
        self.paste_field_names = []
563
        self.inputs = [None]
564

565
        self.on_before_component_elem_id = {}
566
        """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
567

568
        self.on_after_component_elem_id = {}
569
        """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
570

571
    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script class for one tab and sort the instances by visibility.

        Args:
            is_img2img: True for the img2img tab, False for txt2img.

        Instances whose show() returns AlwaysVisible go to alwayson_scripts (and get
        alwayson=True); other truthy results go to selectable_scripts; both kinds are also
        added to self.scripts. Instances whose show() returns a falsy value are dropped. A
        class whose constructor raises is reported and skipped.
        """
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_data in auto_processing_scripts + scripts_data:
            try:
                script = script_data.script_class()
            except Exception:
                errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
                continue

            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img
            script.tabname = "img2img" if is_img2img else "txt2img"

            visibility = script.show(script.is_img2img)

            # AlwaysVisible is a sentinel object, so compare by identity
            if visibility is AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

        self.apply_on_before_component_callbacks()
604

605
    def apply_on_before_component_callbacks(self):
606
        for script in self.scripts:
607
            on_before = script.on_before_component_elem_id or []
608
            on_after = script.on_after_component_elem_id or []
609

610
            for elem_id, callback in on_before:
611
                if elem_id not in self.on_before_component_elem_id:
612
                    self.on_before_component_elem_id[elem_id] = []
613

614
                self.on_before_component_elem_id[elem_id].append((callback, script))
615

616
            for elem_id, callback in on_after:
617
                if elem_id not in self.on_after_component_elem_id:
618
                    self.on_after_component_elem_id[elem_id] = []
619

620
                self.on_after_component_elem_id[elem_id].append((callback, script))
621

622
            on_before.clear()
623
            on_after.clear()
624

625
    def create_script_ui(self, script):
626

627
        script.args_from = len(self.inputs)
628
        script.args_to = len(self.inputs)
629

630
        try:
631
            self.create_script_ui_inner(script)
632
        except Exception:
633
            errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
634

635
    def create_script_ui_inner(self, script):
        """Call script.ui(), register the returned controls and build the script's API description.

        On success the controls are appended to self.inputs and script.args_to is moved to the new
        end, so [args_from, args_to) covers exactly this script's controls. The script's
        infotext_fields / paste_field_names (when set) are merged into the runner's lists.
        If ui() returns None (including when it raises, which wrap_call reports) nothing is registered.
        """
        import modules.api.models as api_models

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
        script.controls = controls
        if controls is None:
            return

        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()

        def describe_control(control):
            """Tag control with its source file name and return the matching api_models.ScriptArg."""
            control.custom_script_source = os.path.basename(script.filename)
            info = api_models.ScriptArg(label=control.label or "")

            for attr in ("value", "minimum", "maximum", "step"):
                attr_value = getattr(control, attr, None)
                if attr_value is not None:
                    setattr(info, attr, attr_value)

            # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
            choices = getattr(control, 'choices', None)
            if choices is not None:
                info.choices = [choice[0] if isinstance(choice, tuple) else choice for choice in choices]

            return info

        api_args = [describe_control(control) for control in controls]

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)
679

680
    def setup_ui_for_section(self, section, scriptlist=None):
681
        if scriptlist is None:
682
            scriptlist = self.alwayson_scripts
683

684
        for script in scriptlist:
685
            if script.alwayson and script.section != section:
686
                continue
687

688
            if script.create_group:
689
                with gr.Group(visible=script.alwayson) as group:
690
                    self.create_script_ui(script)
691

692
                script.group = group
693
            else:
694
                self.create_script_ui(script)
695

696
    def prepare_ui(self):
697
        self.inputs = [None]
698

699
    def setup_ui(self):
        """Build the script UI for one tab and return the list of input components.

        Creates the always-on scripts' UI for the default (None) section, the "Script" dropdown
        (stored as inputs[0]) and the groups of the selectable scripts. It wires the dropdown to
        toggle group visibility and registers infotext handlers so pasted generation parameters
        select the right script. Titles that do not exist in this tab (e.g. from another install's
        infotext or a stale ui-config.json) are treated as "None" instead of raising ValueError.

        Returns:
            self.inputs: the dropdown followed by every script's controls, in UI order.
        """
        all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
        self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        self.inputs[0] = dropdown

        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            # type="index": 0 is "None", i > 0 is selectable_scripts[i - 1]
            if script_index is None:
                script_index = 0
            selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            # a title saved in ui-config.json may no longer exist (script removed or renamed)
            if title == 'None' or title not in self.titles:
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_title(params):
            title = params.get('Script', 'None')
            if isinstance(title, str) and title != 'None' and title not in self.titles:
                # pasted infotext names a script this tab does not have; keep the dropdown valid
                title = 'None'
            return gr.update(value=title)

        def onload_script_visibility(params):
            # presumably invoked once per selectable script's group, in registration order;
            # script_load_ctr tracks which group the current call is for
            title = params.get('Script', None)
            if title and title in self.titles:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)

            # no script in the infotext, or one unknown to this tab (previously a ValueError)
            return gr.update(visible=False)

        self.infotext_fields.append((dropdown, onload_script_title))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        self.apply_on_before_component_callbacks()

        return self.inputs
753

754
    def run(self, p, *args):
755
        script_index = args[0]
756

757
        if script_index == 0 or script_index is None:
758
            return None
759

760
        script = self.selectable_scripts[script_index-1]
761

762
        if script is None:
763
            return None
764

765
        script_args = args[script.args_from:script.args_to]
766
        processed = script.run(p, *script_args)
767

768
        shared.total_tqdm.clear()
769

770
        return processed
771

772
    def before_process(self, p):
773
        for script in self.alwayson_scripts:
774
            try:
775
                script_args = p.script_args[script.args_from:script.args_to]
776
                script.before_process(p, *script_args)
777
            except Exception:
778
                errors.report(f"Error running before_process: {script.filename}", exc_info=True)
779

780
    def process(self, p):
781
        for script in self.alwayson_scripts:
782
            try:
783
                script_args = p.script_args[script.args_from:script.args_to]
784
                script.process(p, *script_args)
785
            except Exception:
786
                errors.report(f"Error running process: {script.filename}", exc_info=True)
787

788
    def before_process_batch(self, p, **kwargs):
        """Call ``before_process_batch(p, *own_args, **kwargs)`` on every always-on script.

        The keyword arguments are forwarded unchanged to each script. Exceptions are
        reported per script and the loop continues.
        """
        for s in self.alwayson_scripts:
            try:
                lo, hi = s.args_from, s.args_to
                own_args = p.script_args[lo:hi]
                s.before_process_batch(p, *own_args, **kwargs)
            except Exception:
                errors.report(f"Error running before_process_batch: {s.filename}", exc_info=True)
795

796
    def after_extra_networks_activate(self, p, **kwargs):
        """Call ``after_extra_networks_activate(p, *own_args, **kwargs)`` on each always-on script.

        ``own_args`` is the script's slice of ``p.script_args``; ``kwargs`` is passed
        through as-is. A script that raises is reported and skipped.
        """
        for hooked in self.alwayson_scripts:
            try:
                args_slice = p.script_args[hooked.args_from:hooked.args_to]
                hooked.after_extra_networks_activate(p, *args_slice, **kwargs)
            except Exception:
                errors.report(f"Error running after_extra_networks_activate: {hooked.filename}", exc_info=True)
803

804
    def process_batch(self, p, **kwargs):
        """Call ``process_batch(p, *own_args, **kwargs)`` on each always-on script.

        Errors raised by a script are reported via ``errors.report``; they do not
        abort the remaining scripts.
        """
        for item in self.alwayson_scripts:
            try:
                start, stop = item.args_from, item.args_to
                item.process_batch(p, *p.script_args[start:stop], **kwargs)
            except Exception:
                errors.report(f"Error running process_batch: {item.filename}", exc_info=True)
811

812
    def postprocess(self, p, processed):
        """Call ``postprocess(p, processed, *own_args)`` on each always-on script.

        ``processed`` is handed to every script unchanged. A script raising an
        exception is reported and the loop moves on.
        """
        for member in self.alwayson_scripts:
            try:
                member_args = p.script_args[member.args_from:member.args_to]
                member.postprocess(p, processed, *member_args)
            except Exception:
                errors.report(f"Error running postprocess: {member.filename}", exc_info=True)
819

820
    def postprocess_batch(self, p, images, **kwargs):
        """Call ``postprocess_batch`` on each always-on script.

        Every script is invoked as
        ``postprocess_batch(p, *own_args, images=images, **kwargs)``, so ``images``
        always arrives as a keyword argument. Failures are reported per script.
        """
        for handler in self.alwayson_scripts:
            try:
                handler_args = p.script_args[handler.args_from:handler.args_to]
                handler.postprocess_batch(p, *handler_args, images=images, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch: {handler.filename}", exc_info=True)
827

828
    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        """Call ``postprocess_batch_list(p, pp, *own_args, **kwargs)`` on each always-on script.

        The same ``pp`` object is passed to every script in turn. A failing script
        is reported and the remaining scripts still run.
        """
        for runner_script in self.alwayson_scripts:
            try:
                lo, hi = runner_script.args_from, runner_script.args_to
                runner_script.postprocess_batch_list(p, pp, *p.script_args[lo:hi], **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {runner_script.filename}", exc_info=True)
835

836
    def post_sample(self, p, ps: PostSampleArgs):
        """Call ``post_sample(p, ps, *own_args)`` on each always-on script.

        ``ps`` is shared across all scripts. Exceptions are reported per script
        through ``errors.report`` without stopping the loop.
        """
        for s in self.alwayson_scripts:
            try:
                sliced = p.script_args[slice(s.args_from, s.args_to)]
                s.post_sample(p, ps, *sliced)
            except Exception:
                errors.report(f"Error running post_sample: {s.filename}", exc_info=True)
843

844
    def on_mask_blend(self, p, mba: MaskBlendArgs):
        """Call ``on_mask_blend(p, mba, *script_args)`` on each always-on script.

        ``script_args`` is the script's ``args_from:args_to`` slice of
        ``p.script_args``, and the same ``mba`` object is passed to every script.
        An exception from one script is reported under this hook's own name, and
        the remaining scripts still run.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.on_mask_blend(p, mba, *script_args)
            except Exception:
                # name the hook that actually failed (was a copy-paste of "post_sample")
                errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True)
851

852
    def postprocess_image(self, p, pp: PostprocessImageArgs):
        """Call ``postprocess_image(p, pp, *own_args)`` on each always-on script.

        All scripts receive the same ``pp`` object. A failing script is reported
        and skipped.
        """
        for candidate in self.alwayson_scripts:
            try:
                extra = p.script_args[candidate.args_from:candidate.args_to]
                candidate.postprocess_image(p, pp, *extra)
            except Exception:
                errors.report(f"Error running postprocess_image: {candidate.filename}", exc_info=True)
859

860
    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
        """Call ``postprocess_maskoverlay(p, ppmo, *script_args)`` on each always-on script.

        ``script_args`` is the script's ``args_from:args_to`` slice of
        ``p.script_args``, and the same ``ppmo`` object is passed to every script.
        An exception from one script is reported under this hook's own name, and
        the remaining scripts still run.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_maskoverlay(p, ppmo, *script_args)
            except Exception:
                # name the hook that actually failed (was a copy-paste of "postprocess_image")
                errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True)
867

868
    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
        """Call ``postprocess_image_after_composite(p, pp, *own_args)`` on each always-on script.

        ``pp`` is shared by all scripts. Any exception is reported for the
        offending script, and iteration continues.
        """
        for target in self.alwayson_scripts:
            try:
                lo, hi = target.args_from, target.args_to
                target_args = p.script_args[lo:hi]
                target.postprocess_image_after_composite(p, pp, *target_args)
            except Exception:
                errors.report(f"Error running postprocess_image_after_composite: {target.filename}", exc_info=True)
875

876
    def before_component(self, component, **kwargs):
        """Notify scripts that a UI component is about to be created.

        First, every ``(callback, script)`` pair stored in
        ``on_before_component_elem_id`` under ``kwargs["elem_id"]`` is called with an
        ``OnComponent`` wrapping ``component``. Then every script's
        ``before_component(component, **kwargs)`` is called. Each failure is
        reported individually and does not interrupt the rest.
        """
        elem_id = kwargs.get("elem_id")
        for hook, owner in self.on_before_component_elem_id.get(elem_id, []):
            try:
                hook(OnComponent(component=component))
            except Exception:
                errors.report(f"Error running on_before_component: {owner.filename}", exc_info=True)

        for owner in self.scripts:
            try:
                owner.before_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running before_component: {owner.filename}", exc_info=True)
888

889
    def after_component(self, component, **kwargs):
        """Notify scripts that a UI component has just been created.

        Callbacks stored in ``on_after_component_elem_id`` under
        ``component.elem_id`` run first, each receiving an ``OnComponent`` wrapper.
        Then every script's ``after_component(component, **kwargs)`` runs. Failures
        are reported per callback or script, and processing continues.
        """
        registered = self.on_after_component_elem_id.get(component.elem_id, [])
        for hook, owner in registered:
            try:
                hook(OnComponent(component=component))
            except Exception:
                errors.report(f"Error running on_after_component: {owner.filename}", exc_info=True)

        for owner in self.scripts:
            try:
                owner.after_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running after_component: {owner.filename}", exc_info=True)
901

902
    def script(self, title):
        """Return the script registered under ``title`` (case-insensitive), or None."""
        key = title.lower()
        return self.title_map.get(key)
904

905
    def reload_sources(self, cache):
        """Replace every script instance with a fresh one built from its (re)loaded module.

        Args:
            cache: dict mapping a script filename to an already loaded module.
                Filenames missing from it are loaded with
                ``script_loading.load_module`` and stored back, so callers can
                share one cache and load each file only once.

        Each new instance keeps the old instance's ``filename``, ``args_from`` and
        ``args_to``, so it still reads the same slice of the script arguments.
        Scripts whose module no longer defines any ``Script`` subclass are left as-is.
        """
        for si, script in enumerate(self.scripts):
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(filename)
                cache[filename] = module

            # isinstance(..., type) also accepts classes built with a custom metaclass
            candidates = [
                cls for cls in module.__dict__.values()
                if isinstance(cls, type) and issubclass(cls, Script)
            ]
            if not candidates:
                continue

            # A module may define several Script subclasses. Previously, every script
            # from such a file was replaced by the *last* subclass found. Prefer the
            # class with the same name as the instance being replaced, and fall back
            # to the last subclass (the old behaviour) if none matches.
            old_name = type(script).__name__
            script_class = next((cls for cls in candidates if cls.__name__ == old_name), candidates[-1])

            new_script = script_class()
            new_script.filename = filename
            new_script.args_from = script.args_from
            new_script.args_to = script.args_to
            self.scripts[si] = new_script
922

923
    def before_hr(self, p):
        """Call ``before_hr(p, *own_args)`` on every always-on script.

        ``own_args`` is the script's ``args_from:args_to`` slice of
        ``p.script_args``. Exceptions are reported per script, and the remaining
        scripts still run.
        """
        for entry in self.alwayson_scripts:
            try:
                bounds = slice(entry.args_from, entry.args_to)
                entry.before_hr(p, *p.script_args[bounds])
            except Exception:
                errors.report(f"Error running before_hr: {entry.filename}", exc_info=True)
930

931
    def setup_scrips(self, p, *, is_ui=True):
        """Call ``setup(p, *own_args)`` on the always-on scripts.

        When ``is_ui`` is falsy, scripts whose ``setup_for_ui_only`` is truthy are
        skipped. A script whose ``setup`` raises is reported through
        ``errors.report``, and the others still run.
        """
        for s in self.alwayson_scripts:
            runs_here = is_ui or not s.setup_for_ui_only
            if runs_here:
                try:
                    own_args = p.script_args[s.args_from:s.args_to]
                    s.setup(p, *own_args)
                except Exception:
                    errors.report(f"Error running setup: {s.filename}", exc_info=True)
941

942
    def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
        """Locate an arg of a specific script in the script args and set its value.

        Args:
            args: all script args of process p (``p.script_args``), as a tuple or a list
            script_name: ``name`` of the target script
            arg_elem_id: elem_id of the target arg's control
            value: the value to set
            fuzzy: if True, ``arg_elem_id`` may be a substring of the control's
                elem_id (controls without an elem_id are skipped); otherwise it must
                match exactly

        Returns:
            The updated script args. If ``args`` is a tuple, a new tuple is returned;
            if it is a list, it is modified in place and returned.

        Raises:
            RuntimeError: if ``script_name`` is not found, if ``arg_elem_id`` matches
                none of the script's controls, or if the matched ``args`` is neither
                a list nor a tuple.
        """
        script = next((x for x in self.scripts if x.name == script_name), None)
        if script is None:
            raise RuntimeError(f"script {script_name} not found")

        for i, control in enumerate(script.controls):
            elem_id = control.elem_id
            if fuzzy:
                # A control created without an elem_id has elem_id None; a substring test
                # against it would raise TypeError and abort the search, so skip it.
                matched = elem_id is not None and arg_elem_id in elem_id
            else:
                matched = arg_elem_id == elem_id
            if not matched:
                continue

            index = script.args_from + i

            if isinstance(args, tuple):
                return args[:index] + (value,) + args[index + 1:]
            elif isinstance(args, list):
                args[index] = value
                return args
            else:
                raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
        raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
970

971

972
# Module-level script runners. All start as None here; presumably they are
# assigned during startup/UI creation elsewhere in the code base — confirm.
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
# NOTE(review): presumably the runner currently in use; it is not assigned in this chunk.
scripts_current: ScriptRunner = None


def reload_script_body_only():
    """Reload the source of every txt2img and img2img script in place.

    Both runners share one module cache, so a script file used by both is
    loaded only once.
    """
    module_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(module_cache)
982

983

984
# Old name kept for compatibility: ``reload_scripts`` is the same function object as ``load_scripts``.
reload_scripts = load_scripts  # compatibility alias
985

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.