import torch
from modules import devices, shared

module_in_gpu = None
cpu = torch.device("cpu")


def send_everything_to_cpu():
    global module_in_gpu

    if module_in_gpu is not None:
        module_in_gpu.to(cpu)

    module_in_gpu = None


def is_needed(sd_model):
    # --medvram-sdxl only counts when the checkpoint is SDXL (it exposes a
    # `conditioner`); `and` binds tighter than `or`, parentheses make that explicit
    return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or (shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner'))


def apply(sd_model):
    enable = is_needed(sd_model)
    shared.parallel_processing_allowed = not enable

    if enable:
        setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
    else:
        sd_model.lowvram = False


def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True

    parents = {}
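    # maps a hooked submodule to the bigger module that should be moved to the
    # GPU as one unit when that submodule's forward() fires; send_me_to_gpu
    # below resolves through this dict before swapping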

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
        we add this as a forward_pre_hook to a lot of modules, and this way all but one of them
        will stay on the CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu is module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(devices.device)
        module_in_gpu = module
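    # forward_pre_hook fires just before a hooked module's forward(), so
    # weights migrate to the GPU lazily, right when a module is about to run,
    # and at most one tracked module occupies VRAM at a time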

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods

    first_stage_model = sd_model.first_stage_model
    first_stage_model_encode = sd_model.first_stage_model.encode
    first_stage_model_decode = sd_model.first_stage_model.decode

    def first_stage_model_encode_wrap(x):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_encode(x)

    def first_stage_model_decode_wrap(z):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    to_remain_in_cpu = [
        (sd_model, 'first_stage_model'),
        (sd_model, 'depth_model'),
        (sd_model, 'embedder'),
        (sd_model, 'model'),
    ]

    is_sdxl = hasattr(sd_model, 'conditioner')
    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

    if is_sdxl:
        to_remain_in_cpu.append((sd_model, 'conditioner'))
    elif is_sd2:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
    else:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
    stored = []
    for obj, field in to_remain_in_cpu:
        module = getattr(obj, field, None)
        stored.append(module)
        setattr(obj, field, None)

    # send the model to GPU.
    sd_model.to(devices.device)

    # put modules back. the modules will be in CPU.
    for (obj, field), module in zip(to_remain_in_cpu, stored):
        setattr(obj, field, module)
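    # this works because Module.to() only moves what is still attached: with
    # the big submodules detached above, the .to(devices.device) call relocates
    # just the small remainder of the model, and reattaching afterwards leaves
    # the big submodules on the CPU where they started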

    # register hooks for the first three models
    if is_sdxl:
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if sd_model.depth_model:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if sd_model.embedder:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(devices.device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of the third model
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)


def is_enabled(sd_model):
    return sd_model.lowvram
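

# ---------------------------------------------------------------------------
# Minimal standalone sketch of the swapping pattern used above, runnable
# outside webui (illustration only; `blocks`, `swap_in`, and `current` are
# made-up names, not part of this module's API). At most one toy submodule
# sits on the target device at a time, courtesy of forward_pre_hook.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    target = torch.device("cuda") if torch.cuda.is_available() else cpu

    blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
    current = None  # toy counterpart of module_in_gpu

    def swap_in(module, _inputs):
        global current
        if current is module:
            return
        if current is not None:
            current.to(cpu)  # evict whichever block ran last
        module.to(target)
        current = module

    for block in blocks:
        block.register_forward_pre_hook(swap_in)

    x = torch.randn(1, 8)
    for block in blocks:
        x = block(x.to(target))  # each forward() evicts the previous block

    # only the block that ran last should still be on `target`
    print({name: p.device.type for name, p in blocks.named_parameters()})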