stable-diffusion-webui

Форк
0
283 строки · 8.3 Кб
1
import collections
import glob
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional

from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
9

10

11
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
12
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
13
vae_dict = {}
14

15

16
base_vae = None
17
loaded_vae_file = None
18
checkpoint_info = None
19

20
checkpoints_loaded = collections.OrderedDict()
21

22

23
def get_loaded_vae_name():
    """Return the filename of the currently loaded external VAE, or None if none is loaded."""
    if loaded_vae_file is None:
        return None
    return os.path.basename(loaded_vae_file)
28

29

30
def get_loaded_vae_hash():
    """Return the first 10 hex chars of the loaded VAE's sha256, or None when unavailable."""
    if loaded_vae_file is None:
        return None

    digest = hashes.sha256(loaded_vae_file, 'vae')
    if not digest:
        return None
    return digest[0:10]
37

38

39
def get_base_vae(model):
    """Return the cached base-VAE state dict for *model*, or None.

    The snapshot is only valid when it was taken for the same checkpoint
    that *model* currently has loaded.
    """
    # Check `model` first: in the original order (`... and model`) the trailing
    # truthiness test was dead code, because `model.sd_checkpoint_info` would
    # already have raised AttributeError for a None model.
    if model and base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
43

44

45
def store_base_vae(model):
    """Snapshot the model's built-in VAE weights so they can be restored later."""
    global base_vae, checkpoint_info
    if checkpoint_info == model.sd_checkpoint_info:
        return  # already snapshotted for this checkpoint
    # An externally loaded VAE must never be captured as the "base" one.
    assert not loaded_vae_file, "Trying to store non-base VAE!"
    base_vae = deepcopy(model.first_stage_model.state_dict())
    checkpoint_info = model.sd_checkpoint_info
51

52

53
def delete_base_vae():
    """Drop the cached base-VAE snapshot and its checkpoint association."""
    global base_vae, checkpoint_info
    checkpoint_info = None
    base_vae = None
57

58

59
def restore_base_vae(model):
    """Put the model's original (base) VAE weights back, then drop the snapshot."""
    global loaded_vae_file
    snapshot_matches = base_vae is not None and checkpoint_info == model.sd_checkpoint_info
    if snapshot_matches:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()
66

67

68
def get_filename(filepath):
    """Return just the final path component of *filepath*."""
    return os.path.split(filepath)[1]
70

71

72
def refresh_vae_list():
    """Rebuild the global `vae_dict` mapping of VAE filename -> full path.

    Scans the checkpoint directory for `*.vae.*` companions, the dedicated
    VAE directory, and any user-supplied --ckpt-dir / --vae-dir locations,
    then orders the entries by natural filename sort.
    """
    vae_dict.clear()

    # Named `search_globs` (not `paths`) to avoid shadowing the imported
    # `modules.paths` module.
    search_globs = [
        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
        os.path.join(sd_models.model_path, '**/*.vae.pt'),
        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
        os.path.join(vae_path, '**/*.ckpt'),
        os.path.join(vae_path, '**/*.pt'),
        os.path.join(vae_path, '**/*.safetensors'),
    ]

    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
        search_globs += [
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
        ]

    if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
        search_globs += [
            os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
        ]

    candidates = []
    for pattern in search_globs:
        candidates += glob.iglob(pattern, recursive=True)

    # Later hits overwrite earlier ones for duplicate basenames.
    for filepath in candidates:
        vae_dict[get_filename(filepath)] = filepath

    # dict.update() never reorders keys that already exist, so the previous
    # `vae_dict.update(dict(sorted(...)))` was a no-op. Rebuild the dict to
    # actually apply the natural-sort order.
    sorted_items = sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))
    vae_dict.clear()
    vae_dict.update(sorted_items)
107

108

109
def find_vae_near_checkpoint(checkpoint_file):
    """Return a known VAE whose filename starts with the checkpoint's stem, or None."""
    # Stem = basename with the last extension stripped (keeps earlier dots).
    stem = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
    matches = (path for path in vae_dict.values() if os.path.basename(path).startswith(stem))
    return next(matches, None)
116

117

118
@dataclass
class VaeResolution:
    """Outcome of one VAE-resolution step: which file to use and why.

    The original annotated `vae`/`source` as plain `str` even though the
    default is None; `Optional[str]` reflects the actual contract.
    """

    # Full path of the chosen VAE file; None means "use no external VAE".
    vae: Optional[str] = None
    # Human-readable origin of the decision, used in log messages.
    source: Optional[str] = None
    # False when this resolver stage could not decide and the next should run.
    resolved: bool = True

    def tuple(self):
        """Return (vae, source) for callers that unpack the result."""
        return self.vae, self.source
126

127

128
def is_automatic():
    """True when the VAE setting asks for automatic selection.

    "auto" is the legacy spelling kept for people with old configs.
    """
    value = shared.opts.sd_vae
    return value == "Automatic" or value == "auto"
130

131

132
def resolve_vae_from_setting() -> VaeResolution:
    """Resolve the VAE from the settings value, if it decides anything."""
    if shared.opts.sd_vae == "None":
        # Explicitly disabled: resolved, with no VAE file.
        return VaeResolution()

    chosen = vae_dict.get(shared.opts.sd_vae, None)
    if chosen is not None:
        return VaeResolution(chosen, 'specified in settings')

    # A concrete (non-automatic) name that we couldn't find deserves a warning.
    if not is_automatic():
        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")

    return VaeResolution(resolved=False)
144

145

146
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
    """Resolve the VAE from the checkpoint's user metadata, if one is recorded there."""
    user_meta = extra_networks.get_user_metadata(checkpoint_file)
    chosen = user_meta.get("vae", None)

    if chosen is None or chosen == "Automatic":
        return VaeResolution(resolved=False)

    if chosen == "None":
        # Metadata explicitly disables the external VAE.
        return VaeResolution()

    vae_file = vae_dict.get(chosen, None)
    if vae_file is not None:
        return VaeResolution(vae_file, "from user metadata")

    return VaeResolution(resolved=False)
158

159

160
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
    """Resolve to a VAE file sitting next to the checkpoint, unless settings override it."""
    nearby = find_vae_near_checkpoint(checkpoint_file)
    if nearby is None:
        return VaeResolution(resolved=False)
    # A concrete settings choice takes precedence when override mode is on.
    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
        return VaeResolution(resolved=False)
    return VaeResolution(nearby, 'found near the checkpoint')
166

167

168
def resolve_vae(checkpoint_file) -> VaeResolution:
    """Pick the VAE to use for *checkpoint_file*, trying sources in priority order.

    Priority: --vae-path command-line flag, then the settings value when it
    overrides per-model preferences, then user metadata, then a file found
    next to the checkpoint, and finally the settings value as fallback.
    """
    if shared.cmd_opts.vae_path is not None:
        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')

    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
        return resolve_vae_from_setting()

    for resolver in (resolve_vae_from_user_metadata, resolve_vae_near_checkpoint):
        res = resolver(checkpoint_file)
        if res.resolved:
            return res

    return resolve_vae_from_setting()
186

187

188
def load_vae_dict(filename, map_location):
    """Read a VAE state dict from disk, dropping loss-related and EMA bookkeeping keys."""
    raw_state = sd_models.read_state_dict(filename, map_location=map_location)
    return {
        key: value
        for key, value in raw_state.items()
        if not key.startswith("loss") and key not in vae_ignore_keys
    }
192

193

194
def load_vae(model, vae_file=None, vae_source="from unknown source"):
195
    global vae_dict, base_vae, loaded_vae_file
196
    # save_settings = False
197

198
    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
199

200
    if vae_file:
201
        if cache_enabled and vae_file in checkpoints_loaded:
202
            # use vae checkpoint cache
203
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
204
            store_base_vae(model)
205
            _load_vae_dict(model, checkpoints_loaded[vae_file])
206
        else:
207
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
208
            print(f"Loading VAE weights {vae_source}: {vae_file}")
209
            store_base_vae(model)
210

211
            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
212
            _load_vae_dict(model, vae_dict_1)
213

214
            if cache_enabled:
215
                # cache newly loaded vae
216
                checkpoints_loaded[vae_file] = vae_dict_1.copy()
217

218
        # clean up cache if limit is reached
219
        if cache_enabled:
220
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
221
                checkpoints_loaded.popitem(last=False)  # LRU
222

223
        # If vae used is not in dict, update it
224
        # It will be removed on refresh though
225
        vae_opt = get_filename(vae_file)
226
        if vae_opt not in vae_dict:
227
            vae_dict[vae_opt] = vae_file
228

229
    elif loaded_vae_file:
230
        restore_base_vae(model)
231

232
    loaded_vae_file = vae_file
233
    model.base_vae = base_vae
234
    model.loaded_vae_file = loaded_vae_file
235

236

237
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    """Copy *vae_dict_1* into the model's first-stage VAE and cast it to the VAE dtype."""
    first_stage = model.first_stage_model
    first_stage.load_state_dict(vae_dict_1)
    first_stage.to(devices.dtype_vae)
241

242

243
def clear_loaded_vae():
    """Forget which external VAE file is loaded (does not touch model weights)."""
    global loaded_vae_file
    loaded_vae_file = None
246

247

248
unspecified = object()  # sentinel for reload_vae_weights: distinguishes "no argument" from an explicit None
249

250

251
def reload_vae_weights(sd_model=None, vae_file=unspecified):
    """Resolve and load the appropriate VAE for *sd_model*, re-applying hijacks.

    Args:
        sd_model: model to update; defaults to the globally shared model.
        vae_file: explicit VAE path, None to unload the external VAE, or the
            `unspecified` sentinel to resolve automatically via resolve_vae().

    Returns:
        The (possibly unchanged) model.
    """
    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename

    if vae_file == unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        # Requested VAE is already active. Return the model for consistency
        # with the load path below (this previously returned None).
        return sd_model

    # Move weights off the GPU before swapping the VAE to limit VRAM spikes.
    if sd_model.lowvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    # Optimization hooks must be undone while first-stage weights change.
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)

    if not sd_model.lowvram:
        sd_model.to(devices.device)

    script_callbacks.model_loaded_callback(sd_model)

    print("VAE weights loaded.")
    return sd_model
284

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.