stable-diffusion-webui
232 lines · 11.2 KB
1import ldm.modules.encoders.modules2import open_clip3import torch4import transformers.utils.hub5
6from modules import shared7
8
class ReplaceHelper:
    """Swaps attributes on arbitrary objects and can undo every swap at once."""

    def __init__(self):
        # records of (owner, attribute_name, previous_value), in replacement order
        self.replaced = []

    def replace(self, obj, field, func):
        """Set ``obj.field`` to ``func`` and remember the previous value.

        Returns the previous value, or None (recording nothing) when the
        attribute does not exist on ``obj``.
        """
        previous = getattr(obj, field, None)
        if previous is None:
            return None

        self.replaced.append((obj, field, previous))
        setattr(obj, field, func)
        return previous

    def restore(self):
        """Put every replaced attribute back and forget the records."""
        for owner, name, previous in self.replaced:
            setattr(owner, name, previous)

        self.replaced.clear()
29
class DisableInitialization(ReplaceHelper):
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changes CLIP and OpenCLIP to not download model weights
    - changes CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        super().__init__()
        # when False, only torch init functions are disabled; CLIP/OpenCLIP loading is untouched
        self.disable_clip = disable_clip

    # NOTE: the duplicated `replace` method that shadowed the identical
    # inherited ReplaceHelper.replace has been removed — behavior is unchanged.

    def __enter__(self):
        # stand-in for torch's weight-init functions: skip the work entirely
        def do_nothing(*args, **kwargs):
            pass

        # force OpenCLIP to build the model without fetching pretrained weights
        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        # build CLIPTextModel from config only, with an empty state dict, so
        # nothing is downloaded; keep name_or_path so later code can identify it
        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        # shared wrapper for all transformers cache getters: prefer local files,
        # fall back to the network only when the local lookup fails
        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            # (parenthesized for clarity; same `and`-binds-tighter-than-`or` semantics)
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or (url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json'):
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        # always neuter torch's expensive weight initialization
        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            # originals are stashed on self so the wrappers above can delegate to them
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
114
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
    which results in those parameters having no values and taking no memory. model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        # user opted out of the RAM optimization: leave torch untouched
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def force_meta(kwargs):
            # redirect the layer constructor's tensor allocation to the meta device
            kwargs["device"] = "meta"
            return kwargs

        def patched_linear_init(*args, **kwargs):
            return saved_linear_init(*args, **force_meta(kwargs))

        def patched_conv2d_init(*args, **kwargs):
            return saved_conv2d_init(*args, **force_meta(kwargs))

        def patched_mha_init(*args, **kwargs):
            return saved_mha_init(*args, **force_meta(kwargs))

        def disabled_to(*args, **kwargs):
            # meta tensors cannot be moved; LoadStateDictOnMeta repairs the params later
            return None

        saved_linear_init = self.replace(torch.nn.Linear, '__init__', patched_linear_init)
        saved_conv2d_init = self.replace(torch.nn.Conv2d, '__init__', patched_conv2d_init)
        saved_mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', patched_mha_init)
        self.replace(torch.nn.Module, 'to', disabled_to)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
144
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta above.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        # the authoritative weight dict; entries are popped from it as they are consumed
        self.state_dict = state_dict
        # target device for materializing meta parameters
        self.device = device
        # maps the first dotted segment of a key to a dtype; '' entry is the default
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        # dtype is chosen by the key's first dotted component (e.g. 'model' in 'model.x.y')
        # NOTE(review): a key with no '.' would raise ValueError here — presumably all
        # state-dict keys are dotted; confirm against callers
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        # user opted out of the RAM optimization: leave torch untouched
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            # keys moved from sd into the (meta-filled) state_dict for this module;
            # they are removed again after the original load runs
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                # take the real weight out of sd (freeing it from sd) and put it
                # into the dict torch will actually read, with dtype conversion
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    # materialize the meta parameter on the target device so the
                    # original loader has a real tensor to copy into
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                # buffers are moved over as-is, without dtype conversion
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            # drop the real weights we injected so state_dict stays mostly empty
            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
            the function and does not call the original) the state dict will just fail to load because weights
            would be on the meta device.
            """

            if state_dict is sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(module, state_dict, strict=strict)

        # patch the generic entry point plus the per-layer _load_from_state_dict
        # overrides; each lambda late-binds its own saved original
        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()