stable-diffusion-webui
93 строки · 2.6 Кб
1import torch.nn
2
3from modules import script_callbacks, shared, devices
4
# Options reported by script callbacks; refreshed in place by list_unets().
unet_options = []

# The currently applied SdUnetOption and the SdUnet created from it; both are
# None while the built-in unet is in use (see apply_unet()).
current_unet_option = current_unet = None

# Not used; kept only temporarily so code that still references it keeps working.
original_forward = None
9
def list_unets():
    """Re-query script callbacks for available unets and store them in unet_options.

    The module-level list is updated in place rather than rebound, so any other
    reference to the same list object sees the refreshed contents.
    """
    unet_options[:] = script_callbacks.list_unets_callback()
15
16
def get_unet_option(option=None):
    """Resolve a unet selection to an entry of ``unet_options``.

    Args:
        option: label of the desired unet. When falsy, ``shared.opts.sd_unet``
            is used instead. Two labels are special: ``"None"`` selects the
            built-in unet, and ``"Automatic"`` picks the first option whose
            ``model_name`` equals the loaded checkpoint's ``model_name``
            (falling back to ``"None"`` when nothing matches).

    Returns:
        The first option in ``unet_options`` whose ``label`` equals the
        resolved selection, or None when the built-in unet is selected or no
        option carries that label.
    """
    option = option or shared.opts.sd_unet

    if option == "None":
        return None

    if option == "Automatic":
        name = shared.sd_model.sd_checkpoint_info.model_name
        # Stop at the first match instead of materializing every match.
        match = next((x for x in unet_options if x.model_name == name), None)
        option = match.label if match is not None else "None"

    return next((x for x in unet_options if x.label == option), None)
31
32
def _restore_builtin_unet():
    """Move the built-in diffusion model back to devices.device unless shared.sd_model.lowvram is set."""
    if not shared.sd_model.lowvram:
        shared.sd_model.model.diffusion_model.to(devices.device)


def apply_unet(option=None):
    """Switch the active unet replacement to the one selected by ``option``.

    ``option`` is resolved with get_unet_option(). If it resolves to the option
    that is already applied, nothing happens. Otherwise the current replacement
    (if any) is deactivated, and then one of two things happens:

    * If the option resolved to None, the built-in unet is used again. Its
      diffusion model is moved back to devices.device unless
      shared.sd_model.lowvram is set.
    * Otherwise the built-in diffusion model is moved to devices.cpu and
      devices.torch_gc() is called (presumably to free device memory for the
      replacement). A new SdUnet is then created from the option and
      activated.

    If creating or activating the new unet raises, the state falls back to the
    built-in unet (as if the option resolved to None) and the exception is
    re-raised. This way forward() is never routed to a half-initialized or
    deactivated unet, and a later call with the same option retries.
    """
    global current_unet_option
    global current_unet

    new_option = get_unet_option(option)
    if new_option == current_unet_option:
        return

    if current_unet is not None:
        print(f"Deactivating unet: {current_unet.option.label}")
        current_unet.deactivate()
        # Drop the reference immediately: UNetModel_forward dispatches on
        # current_unet, so it must not keep pointing at a deactivated unet
        # if creating the next one fails below.
        current_unet = None

    current_unet_option = new_option
    if current_unet_option is None:
        _restore_builtin_unet()
        return

    shared.sd_model.model.diffusion_model.to(devices.cpu)
    devices.torch_gc()

    try:
        current_unet = current_unet_option.create_unet()
        current_unet.option = current_unet_option
        print(f"Activating unet: {current_unet.option.label}")
        current_unet.activate()
    except Exception:
        # Roll back to the built-in unet so the recorded state matches what
        # will actually run, then let the caller see the original error.
        current_unet = None
        current_unet_option = None
        _restore_builtin_unet()
        raise
61
62
class SdUnetOption:
    """Describes one selectable unet replacement; subclasses implement create_unet()."""

    # Name of the related checkpoint: with the "Automatic" setting, this option
    # is selected when the loaded checkpoint's model_name matches it.
    model_name = None

    # Name of this unet as shown in the UI.
    label = None

    def create_unet(self):
        """Return the SdUnet object used instead of the built-in unet when making pictures."""
        raise NotImplementedError
73
74
class SdUnet(torch.nn.Module):
    """Base class for a unet that stands in for the built-in one while active.

    Subclasses must implement forward(). activate() and deactivate() are hooks
    that apply_unet() calls when this unet is switched in or out. Both do
    nothing by default.
    """

    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError

    def activate(self):
        """Hook run after this unet has been created and selected; no-op by default."""

    def deactivate(self):
        """Hook run when switching away from this unet; no-op by default."""
84
85
def create_unet_forward(original_forward):
    """Wrap ``original_forward`` so calls are routed to the active replacement unet.

    The returned function calls ``current_unet.forward(x, timesteps, context,
    ...)`` while a replacement unet is set. Otherwise it falls through to
    ``original_forward(self, x, timesteps, context, ...)``.
    """
    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
        active = current_unet
        if active is None:
            return original_forward(self, x, timesteps, context, *args, **kwargs)
        return active.forward(x, timesteps, context, *args, **kwargs)

    return UNetModel_forward
94
95