stable-diffusion-webui

Форк
0
168 строк · 5.7 Кб
1
import importlib
2
import logging
3
import os
4
import sys
5
import warnings
6
from threading import Thread
7

8
from modules.timer import startup_timer
9

10

11
def imports():
    """Import heavy third-party libraries and core project modules, timing each stage.

    After each import group ``startup_timer.record(<label>)`` is called with a
    distinct label, so the startup timing report shows where import time went.
    None of the imported names are used in this function (hence ``noqa: F401``);
    presumably the point is to pay the import cost up front, in a controlled order.
    """
    # Configure noisy loggers before the libraries that log to them are imported.
    logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
    # Drop xformers log records that contain the Triton-availability message.
    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

    import torch  # noqa: F401
    startup_timer.record("import torch")
    import pytorch_lightning  # noqa: F401
    # Bug fix: this was recorded as "import torch" a second time (copy-paste),
    # which reported the pytorch_lightning import under the torch label.
    startup_timer.record("import pytorch_lightning")
    # Silence DeprecationWarnings raised from pytorch_lightning and UserWarnings from torchvision.
    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

    # setdefault: only applies when the user has not set the variable themselves.
    # It must happen before `import gradio`; presumably gradio reads it at import time.
    os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
    import gradio  # noqa: F401
    startup_timer.record("import gradio")

    from modules import paths, timer, import_hook, errors  # noqa: F401
    startup_timer.record("setup paths")

    import ldm.modules.encoders.modules  # noqa: F401
    startup_timer.record("import ldm")

    import sgm.modules.encoders.modules  # noqa: F401
    startup_timer.record("import sgm")

    from modules import shared_init
    shared_init.initialize()
    startup_timer.record("initialize shared")

    from modules import processing, gradio_extensons, ui  # noqa: F401
    startup_timer.record("other imports")
41

42

43
def check_versions():
    """Call ``errors.check_versions()`` unless ``cmd_opts.skip_version_check`` is truthy."""
    from modules.shared_cmd_options import cmd_opts

    if cmd_opts.skip_version_check:
        # The user opted out of the version check; nothing to do.
        return

    from modules import errors
    errors.check_versions()
49

50

51
def initialize():
    """Perform first-time startup, then hand off to ``initialize_rest``.

    Steps, in order:
    1. Apply the ``initialize_util`` process-level fix-ups.
    2. Call ``sd_models.setup_model()``.
    3. Set up the CodeFormer and GFPGAN models from the paths in ``cmd_opts``.
    4. Call ``initialize_rest(reload_script_modules=False)``.

    Each setup stage is followed by a ``startup_timer.record(...)`` call.
    """
    from modules import initialize_util

    # Process-level fix-ups, invoked strictly in this order.
    fixups = (
        initialize_util.fix_torch_version,
        initialize_util.fix_asyncio_event_loop_policy,
        initialize_util.validate_tls_options,
        initialize_util.configure_sigint_handler,
        initialize_util.configure_opts_onchange,
    )
    for fixup in fixups:
        fixup()

    from modules import sd_models
    sd_models.setup_model()
    startup_timer.record("setup SD model")

    from modules.shared_cmd_options import cmd_opts

    from modules import codeformer_model
    # Suppress UserWarnings originating from torchvision.transforms.functional_tensor
    # before the CodeFormer model is set up.
    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
    codeformer_model.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    from modules import gfpgan_model
    gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)
75

76

77
def initialize_rest(*, reload_script_modules: bool = False) -> None:
    """
    Called both from initialize() and when reloading the webui.

    Steps, each followed by a startup_timer record:
    - set samplers, list extensions and restore the config state file;
    - list SD models and localizations, load scripts;
    - load upscalers, refresh the VAE list and the textual inversion templates;
    - list optimizers and UNets;
    - reload hypernetworks and initialize extra networks.

    In UI debug mode (``cmd_opts.ui_debug_mode``) only the Lanczos upscaler is
    installed and scripts are loaded, then the function returns early.

    Unless ``shared.cmd_opts.skip_load_model_at_start`` is set, the SD model is
    loaded on a separate thread.

    Args:
        reload_script_modules: if True, after scripts are loaded, call
            ``importlib.reload`` on every module in ``sys.modules`` whose name
            starts with "modules.ui".
    """
    from modules.shared_cmd_options import cmd_opts

    from modules import sd_samplers
    sd_samplers.set_samplers()
    startup_timer.record("set samplers")

    from modules import extensions
    extensions.list_extensions()
    startup_timer.record("list extensions")

    from modules import initialize_util
    initialize_util.restore_config_state_file()
    startup_timer.record("restore config state file")

    from modules import shared, upscaler, scripts
    if cmd_opts.ui_debug_mode:
        # Debug mode: only the Lanczos upscaler plus scripts; skip all remaining setup.
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        scripts.load_scripts()
        return

    from modules import sd_models
    sd_models.list_models()
    startup_timer.record("list SD models")

    from modules import localization
    localization.list_localizations(cmd_opts.localizations_dir)
    startup_timer.record("list localizations")

    with startup_timer.subcategory("load scripts"):
        scripts.load_scripts()

    if reload_script_modules:
        # The matching modules are collected into a list first.
        # That way sys.modules is not being iterated while reload() runs.
        # reload() may import (and thus add) modules.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    from modules import modelloader
    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    from modules import sd_vae
    sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")

    from modules import textual_inversion
    textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
    # Register the built-in optimizer lister as an on_list_optimizers callback.
    # The list is then built via sd_hijack.list_optimizers().
    # Presumably that list includes callback results — confirm in sd_hijack.
    script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
    sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    from modules import sd_unet
    sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Access the shared.sd_model property to load the model.

        If the model was already loaded by some extension before this access,
        its optimization may be None: the list of optimizers had not been
        filled at that time. So optimizations are applied again here.
        """
        from modules import devices
        # Called inside this thread before the model is touched.
        # NOTE(review): presumably selects the NPU device for this thread — confirm in devices.
        devices.torch_npu_set_device()

        shared.sd_model  # noqa: B018

        if sd_hijack.current_optimizer is None:
            sd_hijack.apply_optimizations()

        devices.first_time_calculation()
    # Load the model on a separate thread so the rest of initialization continues meanwhile.
    # NOTE(review): this reads shared.cmd_opts, whereas the rest of the function uses the
    # locally imported cmd_opts; presumably the same object — confirm.
    if not shared.cmd_opts.skip_load_model_at_start:
        Thread(target=load_model).start()

    from modules import shared_items
    shared_items.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    from modules import ui_extra_networks
    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    from modules import extra_networks
    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    # The timing of the ui_extra_networks setup above also falls under this label.
    startup_timer.record("initialize extra networks")
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.