stable-diffusion-webui
283 строки · 8.3 Кб
1import os
2import collections
3from dataclasses import dataclass
4
5from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
6
7import glob
8from copy import deepcopy
9
10
# Default folder scanned for standalone VAE files: <models_path>/VAE.
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
# Keys removed from every VAE state dict read by load_vae_dict().
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
# VAE file basename -> full path; rebuilt by refresh_vae_list().
vae_dict = dict()


# base_vae: copy of the checkpoint's own VAE weights (captured by store_base_vae).
# loaded_vae_file: path of the external VAE currently applied, None when none is.
# checkpoint_info: checkpoint info that base_vae was captured from.
base_vae = loaded_vae_file = checkpoint_info = None

# VAE state dicts cached by file path; load_vae() evicts from the front (popitem(last=False)).
checkpoints_loaded = collections.OrderedDict()
21
22
def get_loaded_vae_name():
    """Return the basename of the currently loaded external VAE file, or None if none is loaded."""
    return None if loaded_vae_file is None else os.path.basename(loaded_vae_file)
28
29
def get_loaded_vae_hash():
    """Return the first 10 characters of the loaded VAE's SHA-256 hash.

    Returns None when no external VAE is loaded, or when
    ``hashes.sha256`` yields an empty/None result.
    """
    if loaded_vae_file is None:
        return None

    digest = hashes.sha256(loaded_vae_file, 'vae')
    if not digest:
        return None

    return digest[0:10]
37
38
def get_base_vae(model):
    """Return the stored base VAE state dict if it was captured from *model*'s checkpoint.

    Args:
        model: the SD model to check; may be None/falsy, in which case None is returned.

    Returns:
        The stored base VAE state dict, or None when there is no stored snapshot,
        it belongs to a different checkpoint, or *model* is falsy.
    """
    # Check `model` first: testing it after reading model.sd_checkpoint_info
    # (as before) could never protect against a None model.
    if model and base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
43
44
def store_base_vae(model):
    """Snapshot *model*'s current VAE weights as the base VAE for its checkpoint.

    Nothing happens when a snapshot for this checkpoint is already stored.

    Raises:
        AssertionError: if an external VAE file is currently loaded, because the
            model's VAE weights would then not be the checkpoint's own.
    """
    global base_vae, checkpoint_info
    model_info = model.sd_checkpoint_info
    if checkpoint_info == model_info:
        return  # already holding the snapshot for this checkpoint

    assert not loaded_vae_file, "Trying to store non-base VAE!"
    base_vae = deepcopy(model.first_stage_model.state_dict())
    checkpoint_info = model_info
51
52
def delete_base_vae():
    """Forget the stored base VAE snapshot and the checkpoint it was taken from."""
    global base_vae, checkpoint_info
    base_vae = checkpoint_info = None
57
58
def restore_base_vae(model):
    """Load the stored base VAE back into *model* if it came from the same checkpoint.

    On a match, the module-level ``loaded_vae_file`` is reset to None. The stored
    snapshot is discarded afterwards whether or not it matched.
    """
    global loaded_vae_file
    snapshot_matches = base_vae is not None and checkpoint_info == model.sd_checkpoint_info
    if snapshot_matches:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()
66
67
def get_filename(filepath):
    """Return the final path component of *filepath* (same as ``os.path.basename``)."""
    _head, tail = os.path.split(filepath)
    return tail
70
71
def refresh_vae_list():
    """Rescan every VAE location and rebuild ``vae_dict`` in natural-sort order.

    Scanned recursively, in this order:
      * the Stable Diffusion model folder: ``*.vae.ckpt``, ``*.vae.pt``, ``*.vae.safetensors``
      * ``vae_path``: any ``*.ckpt``, ``*.pt``, ``*.safetensors``
      * ``--ckpt-dir`` (if set and a directory): ``*.vae.*`` files as above
      * ``--vae-dir`` (if set and a directory): any ``*.ckpt``, ``*.pt``, ``*.safetensors``

    ``vae_dict`` maps each file's basename to its full path. When two files share a
    basename, the one found later in the order above wins.
    """
    vae_dict.clear()

    # Named `search_patterns` so the imported `paths` module is not shadowed.
    search_patterns = [
        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
        os.path.join(sd_models.model_path, '**/*.vae.pt'),
        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
        os.path.join(vae_path, '**/*.ckpt'),
        os.path.join(vae_path, '**/*.pt'),
        os.path.join(vae_path, '**/*.safetensors'),
    ]

    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
        search_patterns += [
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
        ]

    if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
        search_patterns += [
            os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
        ]

    for pattern in search_patterns:
        for filepath in glob.iglob(pattern, recursive=True):
            vae_dict[get_filename(filepath)] = filepath

    # dict.update() keeps existing keys in their original position, so updating
    # vae_dict with a sorted copy of itself never reorders it. Clear and re-insert
    # instead; mutate in place so the module-level dict object stays the same.
    sorted_items = sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))
    vae_dict.clear()
    vae_dict.update(sorted_items)
107
108
def find_vae_near_checkpoint(checkpoint_file):
    """Return the first known VAE whose filename starts with the checkpoint's name.

    The checkpoint name is its basename with the last extension removed, so
    ``foo.safetensors`` matches e.g. ``foo.vae.pt``. Candidates are checked in
    ``vae_dict`` order. Returns None when nothing matches.
    """
    # NOTE(review): this is a plain prefix test, so "foo" also matches
    # "foo2.vae.pt"; confirm whether that looseness is intended.
    stem = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
    return next(
        (candidate for candidate in vae_dict.values() if os.path.basename(candidate).startswith(stem)),
        None,
    )
116
117
@dataclass
class VaeResolution:
    """Result of one VAE-resolution strategy (see resolve_vae)."""

    # Path of the chosen VAE file; None means no external VAE is selected.
    # (Annotated as optional: None is both the default and a deliberate value.)
    vae: "str | None" = None
    # Human-readable description of where the choice came from; shown in log messages.
    source: "str | None" = None
    # False means this strategy made no decision and the next one should be tried.
    resolved: bool = True

    def tuple(self):
        """Return ``(vae, source)`` for tuple-unpacking at call sites."""
        return self.vae, self.source
126
127
def is_automatic():
    """Return True if the global ``sd_vae`` setting requests automatic VAE selection."""
    # "auto" is still accepted because older configs stored that value.
    automatic_values = {"Automatic", "auto"}
    return shared.opts.sd_vae in automatic_values
130
131
def resolve_vae_from_setting() -> VaeResolution:
    """Resolve the VAE from the global ``sd_vae`` setting.

    * "None" -> resolved, with no VAE file.
    * a name present in ``vae_dict`` -> that file.
    * anything else -> unresolved. A warning is printed unless the setting is automatic.
    """
    selected = shared.opts.sd_vae
    if selected == "None":
        return VaeResolution()

    found = vae_dict.get(selected, None)
    if found is None:
        if not is_automatic():
            print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
        return VaeResolution(resolved=False)

    return VaeResolution(found, 'specified in settings')
144
145
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
    """Resolve the VAE from the ``"vae"`` entry in the checkpoint's user metadata.

    A missing entry, "Automatic", or a name not present in ``vae_dict`` leaves the
    result unresolved. "None" resolves to no VAE file.
    """
    requested = extra_networks.get_user_metadata(checkpoint_file).get("vae", None)
    if requested is None or requested == "Automatic":
        return VaeResolution(resolved=False)

    if requested == "None":
        return VaeResolution()

    found = vae_dict.get(requested, None)
    if found is None:
        return VaeResolution(resolved=False)

    return VaeResolution(found, "from user metadata")
158
159
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
    """Resolve to a VAE whose filename matches the checkpoint's name, if there is one.

    The match is skipped (unresolved) when the global setting is configured to
    override per-model preferences and is not automatic.
    """
    candidate = find_vae_near_checkpoint(checkpoint_file)
    if candidate is None:
        return VaeResolution(resolved=False)

    settings_take_priority = shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic()
    if settings_take_priority:
        return VaeResolution(resolved=False)

    return VaeResolution(candidate, 'found near the checkpoint')
166
167
def resolve_vae(checkpoint_file) -> VaeResolution:
    """Decide which VAE to use for *checkpoint_file*.

    Precedence:
      1. ``--vae-path`` command-line argument.
      2. The global setting, when it overrides per-model preferences and is not automatic.
      3. The checkpoint's user metadata.
      4. A VAE file matching the checkpoint's name.
      5. The global setting.
    """
    cli_vae = shared.cmd_opts.vae_path
    if cli_vae is not None:
        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')

    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
        return resolve_vae_from_setting()

    for strategy in (resolve_vae_from_user_metadata, resolve_vae_near_checkpoint):
        outcome = strategy(checkpoint_file)
        if outcome.resolved:
            return outcome

    return resolve_vae_from_setting()
186
187
def load_vae_dict(filename, map_location):
    """Read a VAE checkpoint and return its state dict minus loss/EMA bookkeeping keys.

    Keys starting with "loss" and keys listed in ``vae_ignore_keys`` are dropped.
    """
    raw_state = sd_models.read_state_dict(filename, map_location=map_location)
    return {
        key: tensor
        for key, tensor in raw_state.items()
        if not key.startswith("loss") and key not in vae_ignore_keys
    }
192
193
def load_vae(model, vae_file=None, vae_source="from unknown source"):
    """Apply the VAE weights from *vae_file* to *model*, or restore its base VAE.

    Args:
        model: the SD model whose ``first_stage_model`` receives the weights.
        vae_file: path to a VAE checkpoint. If falsy and an external VAE is
            currently loaded, the stored base VAE is restored instead.
        vae_source: description of where *vae_file* came from; used only in
            log messages.

    Side effects: updates the module-level ``loaded_vae_file``, the cache
    ``checkpoints_loaded``, and ``vae_dict``. Sets ``model.base_vae`` and
    ``model.loaded_vae_file``.

    Raises:
        AssertionError: if *vae_file* is not cached and is not an existing file.
    """
    # Only loaded_vae_file is rebound here; vae_dict/base_vae are just read or mutated.
    global loaded_vae_file

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
            # Mark as most recently used; without this the eviction below is FIFO, not LRU.
            checkpoints_loaded.move_to_end(vae_file)
        else:
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # Evict least-recently-used entries once over the limit; the "+ 1" counts
        # the VAE that is in use right now.
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:
                checkpoints_loaded.popitem(last=False)

        # Make sure vae_dict has an entry for this file even if the folder scan did
        # not find it. refresh_vae_list() clears vae_dict, so such an entry does not
        # survive a refresh.
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file

    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file
    model.base_vae = base_vae
    model.loaded_vae_file = loaded_vae_file
235
236
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    """Internal: load *vae_dict_1* into ``model.first_stage_model`` and cast it to ``devices.dtype_vae``."""
    first_stage = model.first_stage_model
    first_stage.load_state_dict(vae_dict_1)
    first_stage.to(devices.dtype_vae)
241
242
def clear_loaded_vae():
    """Reset the record of which external VAE file is loaded.

    Only the module-level ``loaded_vae_file`` is changed; model weights are not touched.
    """
    global loaded_vae_file
    loaded_vae_file = None
246
247
# Sentinel default for reload_vae_weights so an explicit vae_file=None can be told apart from "not given".
unspecified = object()


def reload_vae_weights(sd_model=None, vae_file=unspecified):
    """Apply the appropriate VAE to *sd_model* if it differs from the one already loaded.

    Args:
        sd_model: model to update; ``shared.sd_model`` is used when falsy.
        vae_file: VAE path to apply. When omitted, it is chosen by resolve_vae()
            from the model's checkpoint file. A falsy value (e.g. None) makes
            load_vae() restore the checkpoint's own VAE.

    Returns:
        *sd_model* after the VAE was reloaded, or None when the requested VAE is
        already the loaded one and nothing was done.
    """
    if not sd_model:
        sd_model = shared.sd_model

    # Avoid a local named `checkpoint_info`, which would shadow the module-level global.
    checkpoint_file = sd_model.sd_checkpoint_info.filename

    # Identity check: `unspecified` is a unique sentinel object.
    if vae_file is unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        return

    # Move the model to the CPU (or let lowvram offload everything) during the swap.
    if sd_model.lowvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    # Remove the hijack while the VAE weights are replaced, then re-apply it.
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)

    if not sd_model.lowvram:
        sd_model.to(devices.device)

    script_callbacks.model_loaded_callback(sd_model)

    print("VAE weights loaded.")
    return sd_model
284