stable-diffusion-webui

Форк
0
220 строк · 8.2 Кб
1
import base64
2
import json
3
import warnings
4

5
import numpy as np
6
import zlib
7
from PIL import Image, ImageDraw
8
import torch
9

10

11
class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that additionally understands ``torch.Tensor`` values.

    A tensor is emitted as ``{"TORCHTENSOR": <nested list of its values>}``, the
    shape that ``EmbeddingDecoder`` turns back into a tensor. Every other type is
    handed to the stock ``json.JSONEncoder`` fallback, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, torch.Tensor):
            return super().default(obj)
        nested_values = obj.detach().cpu().numpy().tolist()
        return {'TORCHTENSOR': nested_values}


class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that rebuilds tensors written by ``EmbeddingEncoder``.

    Each decoded JSON object that contains a ``"TORCHTENSOR"`` key is replaced by
    ``torch.from_numpy(np.array(<that key's value>))``; every other object is
    passed through untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        if 'TORCHTENSOR' not in d:
            return d
        as_array = np.array(d['TORCHTENSOR'])
        return torch.from_numpy(as_array)


def embedding_to_b64(data):
    """Serialize ``data`` to JSON (tensors via ``EmbeddingEncoder``) and base64-encode it.

    Returns:
        ``bytes`` containing the base64 text; ``embedding_from_b64`` reverses it.
    """
    return base64.b64encode(json.dumps(data, cls=EmbeddingEncoder).encode('utf-8'))


def embedding_from_b64(data):
    """Base64-decode ``data`` and parse the JSON it contains.

    ``data`` may be anything ``base64.b64decode`` accepts (bytes or an ASCII str).
    Objects tagged by ``EmbeddingEncoder`` are turned back into tensors by
    ``EmbeddingDecoder``.
    """
    return json.loads(base64.b64decode(data), cls=EmbeddingDecoder)


def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Yield an endless linear-congruential stream of values in ``0..254``.

    The internal state advances as ``state = (a * state + c) % m``, starting from
    ``seed``. Each step yields ``state % 255``. ``xor_block`` builds its scrambling
    mask from ``lcg()`` with these defaults, so they must not change, or existing
    images will no longer decode.
    """
    state = seed
    while True:
        state = (state * a + c) % m
        yield state % 255


def xor_block(block):
    """XOR the low nibble of every element of ``block`` with a fixed pseudo-random mask.

    The mask is taken, in C order, from a fresh ``lcg()`` with default parameters.
    It is therefore identical for every call on an array of the same shape, so
    applying this function twice restores the original low nibbles. This is how
    ``extract_image_data_embed`` undoes ``insert_image_data_embed``. High nibbles
    pass through unchanged (after the cast to ``uint8``).

    Args:
        block: numpy array of any shape; it is cast to ``uint8`` first.

    Returns:
        A new ``uint8`` array with the same shape as ``block``.
    """
    # ``block.size`` replaces ``np.product(block.shape)``: ``np.product`` was
    # deprecated in NumPy 1.25 and removed in NumPy 2.0. ``np.fromiter`` pulls
    # exactly ``block.size`` values (all in 0..254, so uint8 is lossless) without
    # building an intermediate list of Python ints.
    mask = np.fromiter(lcg(), dtype=np.uint8, count=block.size).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), mask & 0x0F)


def style_block(block, sequence):
    """XOR a staggered grid of grey dots into the upper nibble of ``block``.

    Dots are drawn on a black RGB canvas of ``block.shape[1]`` x ``block.shape[0]``
    pixels. There is one ellipse per 8-px step in each direction, starting at -6,
    with bounding box ``(x + shift, y, x + shift + 6, y + 6)``. ``shift`` is 4 on
    even dot rows and 0 on odd ones. Shades are taken cyclically from ``sequence``,
    walking column by column (all dots of one x before the next x).

    Args:
        block: array of shape ``(H, W, 3)``. Presumably uint8, as produced by
            ``insert_image_data_embed``; TODO confirm for other callers.
        sequence: indexable of 0..255 shades; must be non-empty whenever at
            least one dot is drawn.

    Returns:
        ``block ^ (canvas & 0xF0)``, a new array.
    """
    height, width = block.shape[0], block.shape[1]
    canvas = Image.new('RGB', (width, height))
    painter = ImageDraw.Draw(canvas)

    dot_origins = (
        (x, y, 4 if row % 2 == 0 else 0)
        for x in range(-6, width, 8)
        for row, y in enumerate(range(-6, height, 8))
    )
    for k, (x, y, shift) in enumerate(dot_origins):
        shade = sequence[k % len(sequence)]
        painter.ellipse((x + shift, y, x + 6 + shift, y + 6), fill=(shade, shade, shade))

    # Only the high nibble of the drawing is mixed in; the low nibble stays free for data.
    dots = np.array(canvas).astype(np.uint8) & 0xF0
    return block ^ dots


def _nibble_panel(nibbles, h, d=3):
    """Tile a 1-D nibble stream into an ``(h, W, d)`` panel, repeating it cyclically via ``np.resize``."""
    # Round the length up to a multiple of h, then of h*d. If the length is
    # already a multiple, a full extra h (resp. h*d) is still added. The extra
    # elements come after the end of the zlib stream, and this rule is kept so the
    # layout of generated images does not change.
    size = nibbles.shape[0] + (h - (nibbles.shape[0] % h))
    size = size + ((h * d) - (size % (h * d)))
    return np.resize(nibbles, size).reshape((h, -1, d))


def _edge_style_shades(tensor, limit=1024):
    """Map up to ``limit`` values of the first embedding vector to uint8 shades ``|v| / max|v| * 255``.

    A 1-D tensor is treated as that single vector. An all-zero vector yields
    all-zero shades instead of computing 0/0 (NaN), whose uint8 cast is undefined.
    An empty vector raises ``ValueError`` from ``max``.
    """
    values = tensor.detach().cpu().numpy()
    if values.ndim > 1:
        values = values[0]
    magnitudes = np.abs(np.asarray(values, dtype=np.float64).reshape(-1)[:limit])
    peak = magnitudes.max()
    if peak == 0:
        return np.zeros(magnitudes.shape, dtype=np.uint8)
    return (magnitudes / peak * 255).astype(np.uint8)


def insert_image_data_embed(image, data):
    """Return ``image`` framed by two noise-like panels that encode ``data``.

    ``data`` is serialized with ``EmbeddingEncoder``, zlib-compressed and split into
    low and high nibbles. Each nibble stream is tiled into an ``(image.height, W, 3)``
    panel. A dot pattern derived from the embedding values is XOR-ed into each
    panel's upper nibble (``style_block``), and the lower nibble is scrambled with
    the fixed LCG mask (``xor_block``). The low-nibble panel is pasted left of
    ``image`` and the high-nibble panel right of it. Each is separated from
    ``image`` by a 1-pixel black column, which ``extract_image_data_embed`` uses to
    locate them.

    Args:
        image: PIL image to frame.
        data: dict with a ``'string_to_param'`` mapping whose first value is a
            torch tensor (its first row, or the whole tensor if 1-D, styles the dots).

    Returns:
        A new RGB PIL image of width ``image.width + 2 * W + 2``.
    """
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    data_np_ = np.frombuffer(data_compressed, np.uint8)
    h = image.size[1]

    # Both nibble streams have the same length, so both panels get the same width.
    data_np_low = _nibble_panel(data_np_ & 0x0F, h, d)
    data_np_high = _nibble_panel(data_np_ >> 4, h, d)

    edge_style = _edge_style_shades(list(data['string_to_param'].values())[0])

    # style_block only alters the upper nibble and xor_block only the lower one,
    # so the (scrambled) payload lives in the low nibble of each panel.
    data_np_low = xor_block(style_block(data_np_low, sequence=edge_style))
    data_np_high = xor_block(style_block(data_np_high, sequence=edge_style[::-1]))

    # The panels are uint8 (H, W, 3) arrays, which Pillow maps to RGB by itself;
    # the explicit ``mode=`` argument of Image.fromarray is deprecated.
    im_low = Image.fromarray(data_np_low)
    im_high = Image.fromarray(data_np_high)

    background = Image.new('RGB', (image.size[0] + im_low.size[0] + im_high.size[0] + 2, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0] + 1, 0))
    background.paste(im_high, (im_low.size[0] + 1 + image.size[0] + 1, 0))

    return background


def crop_black(img, tol=0):
    """Trim the dark border of an ``(H, W, C)`` image array.

    A pixel counts as bright only when *all* of its channels exceed ``tol``.
    The result is the tightest row/column window containing every bright pixel.
    If there are none, the whole array is returned unchanged.
    """
    bright = np.all(img > tol, axis=2)
    any_in_col = np.any(bright, axis=0)
    any_in_row = np.any(bright, axis=1)

    left = np.argmax(any_in_col)
    right = any_in_col.size - np.argmax(any_in_col[::-1])
    top = np.argmax(any_in_row)
    bottom = any_in_row.size - np.argmax(any_in_row[::-1])

    return img[top:bottom, left:right]


def extract_image_data_embed(image):
    """Recover the data written by ``insert_image_data_embed``, or return ``None``.

    The image is converted to RGB, its dark border trimmed (``crop_black``), and
    only the low nibble of each channel is kept. Columns whose low nibbles are
    all zero act as separators: everything left of the first one is the
    low-nibble panel, and everything right of the last one is the high-nibble
    panel. Each panel is unscrambled with ``xor_block``. The two are then
    recombined into bytes, zlib-decompressed and JSON-decoded with
    ``EmbeddingDecoder``.

    Returns:
        The decoded object, or ``None`` (after printing a message) when fewer than
        two separator columns exist or the two panels differ in shape.

    Raises:
        Errors from ``zlib.decompress`` / ``json.loads`` propagate when the panels
        do not contain a valid payload.
    """
    # np.asarray on the RGB conversion yields the (h, w, 3) uint8 pixel array
    # directly, instead of going through getdata() (one Python tuple per pixel).
    pixels = np.asarray(image.convert('RGB'))
    outarr = crop_black(pixels) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)[0]
    if black_cols.shape[0] < 2:
        print('No Image data blocks found.')
        return None

    data_block_lower = outarr[:, :black_cols.min(), :]
    data_block_upper = outarr[:, black_cols.max() + 1:, :]
    if data_block_lower.shape != data_block_upper.shape:
        # insert_image_data_embed always writes two panels of equal size; other
        # shapes would fail, or silently broadcast, in the nibble merge below.
        print('Image data blocks found, but their sizes do not match.')
        return None

    data_block_lower = xor_block(data_block_lower)
    data_block_upper = xor_block(data_block_upper)

    data_block = ((data_block_upper << 4) | data_block_lower).flatten().tobytes()

    data = zlib.decompress(data_block)
    return json.loads(data, cls=EmbeddingDecoder)


def _fit_caption_fontsize(draw, text, font, fontsize, target_width, max_size=72):
    """Rescale ``fontsize`` so ``text`` spans about ``target_width`` px, capped at ``max_size``.

    ``text`` is measured with ``font`` (presumably a ``fontsize``-sized font) and
    the size is scaled proportionally. The result is floored at 1.
    """
    _, _, width, _ = draw.textbbox((0, 0), text, font=font)
    if width <= 0:
        # Nothing measurable to fit (e.g. empty text); scaling would divide by zero.
        return max_size
    # target_width can be <= 0 on very narrow images; keep the size positive.
    return max(1, min(int(fontsize * (target_width / width)), max_size))


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    """Return an RGBA copy of ``srcimage`` with a title and three footer captions.

    A black vertical alpha gradient, strongest at the top and bottom edges, is
    composited over the image. ``title`` is drawn at the top-left, sized to about
    three quarters of the width minus padding (at most 72). The three footer
    strings are drawn left-, centre- and right-aligned along the bottom in one
    shared size. That size is the smallest per-footer fit for about a third of
    the width (at most 72). Fonts come from ``modules.images.get_font``.

    Args:
        srcimage: PIL image; it is not modified.
        title, footerLeft, footerMid, footerRight: caption strings.
        textfont: deprecated and ignored; a truthy value emits a DeprecationWarning.

    Returns:
        A new RGBA PIL image.
    """
    from modules.images import get_font
    if textfont:
        warnings.warn(
            'passing in a textfont to caption_image_overlay is deprecated and does nothing',
            DeprecationWarning,
            stacklevel=2,
        )
    from math import cos

    image = srcimage.copy()
    fontsize = 32
    factor = 1.5
    height = image.size[1]
    gradient = Image.new('RGBA', (1, height), color=(0, 0, 0, 0))
    for y in range(height):
        mag = 1-cos(y/height*factor)
        mag = max(mag, 1-cos((height-y)/height*factor*1.1))
        # The second term reaches 1 - cos(1.65) > 1 on the top rows, so clamp the
        # alpha explicitly instead of relying on Pillow to clip values above 255.
        gradient.putpixel((0, y), (0, 0, 0, min(int(mag*255), 255)))
    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

    draw = ImageDraw.Draw(image)

    font = get_font(fontsize)
    padding = 10

    fontsize = _fit_caption_fontsize(draw, title, font, fontsize, (image.size[0]*0.75)-(padding*4))
    font = get_font(fontsize)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))

    # Footers are measured with the title font and share the smallest fitted size.
    footer_target = (image.size[0]/3)-(padding)
    footer_fontsize = min(
        _fit_caption_fontsize(draw, footer, font, fontsize, footer_target)
        for footer in (footerLeft, footerMid, footerRight)
    )
    font = get_font(footer_fontsize)

    draw.text((padding, image.size[1]-padding),               footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]/2, image.size[1]-padding),       footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))

    return image


if __name__ == '__main__':
    # Self-test. Expects ``test_embedding.png`` in the working directory. It must
    # carry both a pixel-embedded payload and an 'sd-ti-embedding' text chunk.

    # Use ``with`` so the fixture's file handle is released.
    with Image.open('test_embedding.png') as testEmbed:
        data = extract_image_data_embed(testEmbed)
        assert data is not None

        data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
        assert data is not None

    image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')

    test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}

    embedded_image = insert_image_data_embed(cap_image, test_embed)

    retrieved_embed = extract_image_data_embed(embedded_image)

    # Compare tensors element-wise: str() of a large tensor elides most of its
    # values, so comparing reprs would not check the bulk of the payload.
    assert retrieved_embed is not None
    assert retrieved_embed.keys() == test_embed.keys()
    retrieved_params = retrieved_embed['string_to_param']
    assert retrieved_params.keys() == test_embed['string_to_param'].keys()
    for key, expected in test_embed['string_to_param'].items():
        assert torch.equal(retrieved_params[key], expected)

    embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)

    # Re-embedding the decoded data must reproduce the exact same image.
    assert embedded_image == embedded_image2

    # The scrambling mask depends on this exact LCG sequence; pin it.
    g = lcg()
    shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()

    reference_random = [253, 242, 127,  44, 157,  27, 239, 133,  38,  79, 167,   4, 177,
                         95, 130,  79,  78,  14,  52, 215, 220, 194, 126,  28, 240, 179,
                        160, 153, 149,  50, 105,  14,  21, 218, 199,  18,  54, 198, 193,
                         38, 128,  19,  53, 195, 124,  75, 205,  12,   6, 145,   0,  28,
                         30, 148,   8,  45, 218, 171,  55, 249,  97, 166,  12,  35,   0,
                         41, 221, 122, 215, 170,  31, 113, 186,  97, 119,  31,  23, 185,
                         66, 140,  30,  41,  37,  63, 137, 109, 216,  55, 159, 145,  82,
                         204, 86,  73, 222,  44, 198, 118, 240,  97]

    assert shared_random == reference_random

    hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())

    assert 12731374 == hunna_kay_random_sum

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.