stable-diffusion-webui

Форк
0
98 строк · 3.0 Кб
1
import os
2
import re
3

4
import torch
5
import numpy as np
6

7
from modules import modelloader, paths, deepbooru_model, devices, images, shared
8

9
re_special = re.compile(r'([\\()])')
10

11

12
class DeepDanbooru:
13
    def __init__(self):
14
        self.model = None
15

16
    def load(self):
17
        if self.model is not None:
18
            return
19

20
        files = modelloader.load_models(
21
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
22
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
23
            ext_filter=[".pt"],
24
            download_name='model-resnet_custom_v3.pt',
25
        )
26

27
        self.model = deepbooru_model.DeepDanbooruModel()
28
        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
29

30
        self.model.eval()
31
        self.model.to(devices.cpu, devices.dtype)
32

33
    def start(self):
34
        self.load()
35
        self.model.to(devices.device)
36

37
    def stop(self):
38
        if not shared.opts.interrogate_keep_models_in_memory:
39
            self.model.to(devices.cpu)
40
            devices.torch_gc()
41

42
    def tag(self, pil_image):
43
        self.start()
44
        res = self.tag_multi(pil_image)
45
        self.stop()
46

47
        return res
48

49
    def tag_multi(self, pil_image, force_disable_ranks=False):
50
        threshold = shared.opts.interrogate_deepbooru_score_threshold
51
        use_spaces = shared.opts.deepbooru_use_spaces
52
        use_escape = shared.opts.deepbooru_escape
53
        alpha_sort = shared.opts.deepbooru_sort_alpha
54
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
55

56
        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
57
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
58

59
        with torch.no_grad(), devices.autocast():
60
            x = torch.from_numpy(a).to(devices.device)
61
            y = self.model(x)[0].detach().cpu().numpy()
62

63
        probability_dict = {}
64

65
        for tag, probability in zip(self.model.tags, y):
66
            if probability < threshold:
67
                continue
68

69
            if tag.startswith("rating:"):
70
                continue
71

72
            probability_dict[tag] = probability
73

74
        if alpha_sort:
75
            tags = sorted(probability_dict)
76
        else:
77
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
78

79
        res = []
80

81
        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
82

83
        for tag in [x for x in tags if x not in filtertags]:
84
            probability = probability_dict[tag]
85
            tag_outformat = tag
86
            if use_spaces:
87
                tag_outformat = tag_outformat.replace('_', ' ')
88
            if use_escape:
89
                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
90
            if include_ranks:
91
                tag_outformat = f"({tag_outformat}:{probability:.3f})"
92

93
            res.append(tag_outformat)
94

95
        return ", ".join(res)
96

97

98
model = DeepDanbooru()
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.