# stable-diffusion-webui / sd_disable_initialization.py
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub

from modules import shared


class ReplaceHelper:
    """Tracks monkeypatched attributes so that restore() can put every original back."""

    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def restore(self):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()


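# Minimal sketch (not used by the module itself) of the patch-and-restore pattern
# that ReplaceHelper implements; the patched target, torch.nn.init.kaiming_uniform_,
# is only an illustration.
def _example_replace_helper():
    helper = ReplaceHelper()

    def fake_init(*args, **kwargs):
        pass

    # swap the stub in, remembering the original
    original = helper.replace(torch.nn.init, 'kaiming_uniform_', fake_init)
    assert torch.nn.init.kaiming_uniform_ is fake_init

    # put everything back
    helper.restore()
    assert torch.nn.init.kaiming_uniform_ is original

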
class DisableInitialization(ReplaceHelper):
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changing CLIP and OpenCLIP so they do not download model weights
    - changing CLIP so it does not make requests to check whether there is a newer version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        super().__init__()
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        # identical to ReplaceHelper.replace
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or (url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json'):
                return None

            # prefer the local cache; only hit the network if the file is not cached
            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()


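# Rough, self-contained sketch of what patching the torch.nn.init functions above
# achieves: inside the block, constructing a layer skips weight initialization, so
# its tensors simply keep whatever was in the freshly allocated memory. The layer
# size and the timing code are purely illustrative.
def _example_skipped_torch_init():
    import time

    t0 = time.perf_counter()
    with DisableInitialization(disable_clip=False):
        _ = torch.nn.Linear(4096, 4096)  # kaiming_uniform_/_no_grad_uniform_ are no-ops here
    patched = time.perf_counter() - t0

    t0 = time.perf_counter()
    _ = torch.nn.Linear(4096, 4096)  # normal construction runs the full initialization
    normal = time.perf_counter() - t0

    print(f"with DisableInitialization: {patched:.4f}s, without: {normal:.4f}s")

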
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on the meta device,
    which results in those parameters having no values and taking no memory. model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def set_device(x):
            x["device"] = "meta"
            return x

        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))

        # .to() cannot move meta tensors, so it is made a no-op while these patches are active
        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()


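# Small sketch of the effect of InitializeOnMeta (assuming the RAM optimization was
# not disabled with --disable-model-loading-ram-optimization): layers built inside
# the block get parameters on the meta device, i.e. shape and dtype only, no storage.
def _example_meta_allocation():
    with InitializeOnMeta():
        layer = torch.nn.Linear(1024, 1024)

    print(layer.weight.is_meta)  # True: no memory was allocated for the weight
    print(layer.weight.device)   # device(type='meta')

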
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
    As those parameters are read from state_dict, they are deleted from it, so by the end state_dict is mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta above.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on the meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is that if _load_from_state_dict is not called (if some exotic module overloads
            the function and does not call the original), the state dict will simply fail to load because the weights
            would be on the meta device.
            """

            if state_dict is sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()


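# End-to-end sketch tying the two meta-device helpers together, mirroring the
# docstrings above. `sd_config` and `state_dict` are hypothetical inputs, and
# instantiate_from_config is assumed to be the helper from the ldm package that
# is already imported at the top of this module.
def _example_meta_loading(sd_config, state_dict, device="cpu"):
    from ldm.util import instantiate_from_config

    # build the model skeleton with its heavy parameters on the meta device
    with InitializeOnMeta():
        model = instantiate_from_config(sd_config.model)

    # stream the real weights in module by module, emptying state_dict as it goes
    with LoadStateDictOnMeta(state_dict, device):
        model.load_state_dict(state_dict, strict=False)

    return model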
