stable-diffusion-webui
98 строк · 3.0 Кб
1import os2import re3
4import torch5import numpy as np6
7from modules import modelloader, paths, deepbooru_model, devices, images, shared8
# Matches exactly one backslash, "(" or ")" and captures it as group 1, so a
# substitution with r'\\\1' can prefix each such character with a backslash.
re_special = re.compile(r"([\\()])")


class DeepDanbooru:
    """DeepDanbooru tagger that turns a PIL image into a comma-separated tag string.

    The network is loaded lazily on first use (``load``), kept on the CPU between
    uses, and moved to ``devices.device`` only while tagging.
    """

    def __init__(self):
        # Set by load(); None means the weights have not been loaded yet.
        self.model = None

    def load(self):
        """Load the model weights once and park the model on the CPU.

        Does nothing if the model is already loaded. The checkpoint file is located
        via ``modelloader.load_models`` under ``<models_path>/torch_deepdanbooru``
        (presumably downloading ``model_url`` when absent; that behaviour lives in
        modelloader and is not shown here).
        """
        if self.model is not None:
            return

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
            ext_filter=[".pt"],
            download_name='model-resnet_custom_v3.pt',
        )

        model = deepbooru_model.DeepDanbooruModel()
        # NOTE(review): torch.load unpickles the checkpoint, which executes arbitrary
        # code from the file. That is acceptable only because the file comes from the
        # project's own release URL; consider weights_only=True if the supported torch
        # version allows it.
        model.load_state_dict(torch.load(files[0], map_location="cpu"))
        model.eval()
        model.to(devices.cpu, devices.dtype)

        # Publish the model only after loading succeeded. Assigning self.model earlier
        # would leave a randomly initialised network cached after a failed load, and
        # the early return above would then keep using it.
        self.model = model

    def start(self):
        """Ensure the model is loaded and move it to the inference device."""
        self.load()
        self.model.to(devices.device)

    def stop(self):
        """Release the model from the inference device.

        Unless ``shared.opts.interrogate_keep_models_in_memory`` is set, this moves
        the model back to the CPU and calls ``devices.torch_gc()``.
        """
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model.to(devices.cpu)
            devices.torch_gc()

    def tag(self, pil_image):
        """Tag a single image, moving the model on and off the device around the call.

        ``stop()`` runs in a ``finally`` so that an exception during inference does
        not leave the model occupying device memory.

        Returns the string produced by ``tag_multi``.
        """
        self.start()
        try:
            return self.tag_multi(pil_image)
        finally:
            self.stop()

    def tag_multi(self, pil_image, force_disable_ranks=False):
        """Return a comma-separated tag string for ``pil_image``.

        Expects the model to be loaded and on the device already. ``tag`` calls
        ``start()`` first; direct callers must do the same.

        Output is controlled by these ``shared.opts`` settings:
          - interrogate_deepbooru_score_threshold: tags scoring below this are dropped.
          - deepbooru_sort_alpha: sort alphabetically instead of by descending score.
          - deepbooru_filter_tags: comma-separated tags to drop (spaces are matched
            as underscores).
          - deepbooru_use_spaces: render underscores in tags as spaces.
          - deepbooru_escape: prefix backslashes and parentheses with a backslash.
          - interrogate_return_ranks: render each tag as "(tag:score)", unless
            ``force_disable_ranks`` is true.

        Tags starting with "rating:" are always omitted.
        """
        threshold = shared.opts.interrogate_deepbooru_score_threshold
        use_spaces = shared.opts.deepbooru_use_spaces
        use_escape = shared.opts.deepbooru_escape
        alpha_sort = shared.opts.deepbooru_sort_alpha
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks

        # Resize to 512x512 (resize mode 2; see images.resize_image), add a leading
        # batch axis, and scale 8-bit RGB values to [0, 1].
        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            # First (only) batch element: one score per entry of self.model.tags.
            y = self.model(x)[0].detach().cpu().numpy()

        probability_dict = {}

        for tag, probability in zip(self.model.tags, y):
            if probability < threshold:
                continue

            if tag.startswith("rating:"):
                continue

            probability_dict[tag] = probability

        if alpha_sort:
            tags = sorted(probability_dict)
        else:
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]

        # User filter entries are normalised to the model's underscore form.
        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

        res = []
        for tag in tags:
            if tag in filtertags:
                continue

            probability = probability_dict[tag]
            tag_outformat = tag
            if use_spaces:
                tag_outformat = tag_outformat.replace('_', ' ')
            if use_escape:
                tag_outformat = re_special.sub(r'\\\1', tag_outformat)
            if include_ranks:
                tag_outformat = f"({tag_outformat}:{probability:.3f})"

            res.append(tag_outformat)

        return ", ".join(res)


# Shared module-level instance. Weights stay unloaded (model is None) until the
# first start()/tag() call triggers load().
model = DeepDanbooru()