# stable-diffusion-webui
# 220 lines · 8.2 KB
1import base64
2import json
3import warnings
4
5import numpy as np
6import zlib
7from PIL import Image, ImageDraw
8import torch
9
10
class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that additionally understands torch tensors.

    A tensor is written as ``{'TORCHTENSOR': nested_list}``; this marker key is what
    EmbeddingDecoder looks for when reading the JSON back. Any other unsupported object
    is delegated to the standard encoder (which raises TypeError).
    """

    def default(self, obj):
        if not isinstance(obj, torch.Tensor):
            return super().default(obj)
        host_array = obj.cpu().detach().numpy()
        return {'TORCHTENSOR': host_array.tolist()}
16
17
class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that turns ``{'TORCHTENSOR': [...]}`` objects back into torch tensors.

    Counterpart of EmbeddingEncoder. Tensors are rebuilt through ``np.array``, so a list of
    Python floats comes back as a float64 tensor.
    """

    def __init__(self, *args, **kwargs):
        # Install our hook so every decoded JSON object passes through object_hook.
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        """Return a tensor for marker dicts, otherwise the dict unchanged."""
        return torch.from_numpy(np.array(d['TORCHTENSOR'])) if 'TORCHTENSOR' in d else d
26
27
def embedding_to_b64(data):
    """Serialise *data* to JSON (tensors via EmbeddingEncoder) and base64-encode it.

    Returns:
        bytes: base64 encoding of the UTF-8 encoded JSON text.
    """
    return base64.b64encode(json.dumps(data, cls=EmbeddingEncoder).encode())
31
32
def embedding_from_b64(data):
    """Base64-decode *data* and parse the JSON inside, restoring tensors via EmbeddingDecoder.

    Args:
        data: base64 payload as str or bytes (as produced by embedding_to_b64).
    """
    raw_json = base64.b64decode(data)
    return json.loads(raw_json, cls=EmbeddingDecoder)
36
37
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Endless linear congruential generator: state = (a * state + c) mod m.

    Each yielded value is the new state reduced modulo 255 (not 256), so outputs lie in
    0..254. This quirk must be kept: xor_block scrambles embedded image data with this
    stream, so changing it would make previously saved images unreadable.
    """
    state = seed
    while True:
        state = (state * a + c) % m
        yield state % 255
42
43
def xor_block(block):
    """Scramble (or unscramble) the low nibble of every element of *block*.

    Each element, taken in C (row-major) order, is XORed with the low 4 bits of the
    matching value of a freshly started ``lcg()`` stream. The stream always restarts from
    the same seed and XOR is its own inverse, so applying this twice restores the input.
    The high nibble is never touched because the mask is ``& 0x0F``.

    Args:
        block: numpy array of any shape; it is cast to uint8 before XORing.

    Returns:
        A new uint8 array with the same shape as *block*.
    """
    # block.size is the product of block.shape (np.product was removed in NumPy 2.0).
    # fromiter with count stops after exactly that many draws from the infinite generator.
    mask = np.fromiter(lcg(), dtype=np.uint8, count=block.size).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), mask & 0x0F)
48
49
def style_block(block, sequence):
    """XOR a decorative dot pattern into the high nibble of *block*.

    A grid of 6x6 ellipses is drawn on an RGB canvas the size of *block*. The grid has an
    8 px pitch starting at -6, and every even-numbered row of dots is shifted right by
    4 px. Grey levels are taken cyclically from *sequence* (column by column). Only the
    high nibble of the drawing (``& 0xF0``) is XORed into *block*, so low-nibble data
    survives.

    Args:
        block: array expected to be (rows, cols, 3) so it lines up with the RGB canvas.
        sequence: non-empty indexable of grey levels used as fill colours.

    Returns:
        ``block ^ pattern``.
    """
    height, width = block.shape[0], block.shape[1]
    canvas = Image.new('RGB', (width, height))
    pen = ImageDraw.Draw(canvas)

    dot = 0
    for left in range(-6, canvas.size[0], 8):
        for row, top in enumerate(range(-6, canvas.size[1], 8)):
            shift = 4 if row % 2 == 0 else 0
            tone = sequence[dot % len(sequence)]
            dot += 1
            pen.ellipse((left + shift, top, left + 6 + shift, top + 6), fill=(tone, tone, tone))

    pattern = np.array(canvas).astype(np.uint8) & 0xF0
    return block ^ pattern
66
67
def _edge_shading(data, limit=1024):
    """Derive up to *limit* uint8 grey levels from the first tensor in data['string_to_param'].

    The first row of a 2-D (or deeper) tensor is used, and a 1-D tensor is used as-is.
    Levels are |x| / max|x| * 255. An all-zero (or NaN) vector would compute 0/0, so it
    yields zero levels instead. An empty vector yields a single zero level so the result
    can still be indexed cyclically by style_block.
    """
    values = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()
    if values and isinstance(values[0], list):
        values = values[0]
    magnitudes = np.abs(values[:limit])
    if magnitudes.size == 0:
        return np.zeros(1, dtype=np.uint8)
    peak = np.max(magnitudes)
    if not peak > 0:  # all zeros or NaN: avoid NaN -> undefined uint8 cast
        return np.zeros(magnitudes.shape, dtype=np.uint8)
    return (magnitudes/peak*255).astype(np.uint8)


def insert_image_data_embed(image, data):
    """Embed *data* into two pixel panels attached to the sides of *image*.

    The payload is JSON-serialised (tensors via EmbeddingEncoder) and zlib-compressed. It
    is then split into low and high nibbles. Each nibble stream is padded and reshaped to
    ``(image height, -1, 3)``. A decorative dot pattern is XORed into the high nibble
    (style_block), and the low nibble is scrambled (xor_block).

    Args:
        image: PIL image; its height sets the panel height.
        data: dict containing 'string_to_param'; its first value must be a torch tensor,
            used only to derive the decorative shading.

    Returns:
        New RGB image laid out as [low panel | black col | image | black col | high panel].
    """
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
    data_np_high = data_np_ >> 4
    data_np_low = data_np_ & 0x0F

    h = image.size[1]
    # Pad to a multiple of h, then of h*d, so the streams reshape to (h, -1, d).
    # Each step adds a full period even when already aligned; this is kept so the
    # image layout does not change. np.resize pads by repeating the data. The
    # round trip relies on zlib.decompress ignoring bytes after the end of the stream.
    next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
    next_size = next_size + ((h*d)-(next_size % (h*d)))

    data_np_low = np.resize(data_np_low, next_size).reshape((h, -1, d))
    data_np_high = np.resize(data_np_high, next_size).reshape((h, -1, d))

    edge_style = _edge_shading(data)

    # Decoration lives in the high nibble; the payload nibble is then scrambled.
    data_np_low = xor_block(style_block(data_np_low, sequence=edge_style))
    data_np_high = xor_block(style_block(data_np_high, sequence=edge_style[::-1]))

    im_low = Image.fromarray(data_np_low, mode='RGB')
    im_high = Image.fromarray(data_np_high, mode='RGB')

    # The 1 px black separator columns are what extract_image_data_embed searches for.
    background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0]+1, 0))
    background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))

    return background
102
103
def crop_black(img, tol=0):
    """Crop *img* to the bounding box of pixels whose every channel exceeds *tol*.

    *img* is a numpy array indexed (row, col, channel). If no pixel qualifies, the array
    is returned uncropped, because argmax of an all-False vector is 0.
    """
    bright = np.all(img > tol, axis=2)

    def _span(flags):
        # first True index .. one past the last True index
        start = int(np.argmax(flags))
        stop = flags.shape[0] - int(np.argmax(flags[::-1]))
        return start, stop

    top, bottom = _span(np.any(bright, axis=1))
    left, right = _span(np.any(bright, axis=0))
    return img[top:bottom, left:right]
110
111
def extract_image_data_embed(image):
    """Recover data embedded by insert_image_data_embed from *image*.

    The decoding steps are:
    - Convert the image to RGB, crop black borders (crop_black), and keep only the low
      nibble of each channel.
    - Treat columns whose low nibbles are all zero as separators. The panel left of the
      first separator holds the low nibbles; the panel right of the last separator holds
      the high nibbles.
    - Unscramble both panels with xor_block and recombine them into bytes.
    - zlib-decompress the bytes and parse them with EmbeddingDecoder.

    Returns:
        The decoded object, or None (after printing a message) when fewer than two
        separator columns are found.

    Raises:
        Errors from zlib.decompress / json parsing propagate when separators exist but
        the panels do not hold a valid payload.
    """
    # np.asarray uses Pillow's array interface: same (h, w, 3) uint8 result as building
    # an array from getdata(), without creating a Python tuple per pixel.
    pixels = np.asarray(image.convert('RGB'), dtype=np.uint8)
    outarr = crop_black(pixels) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)[0]
    if black_cols.shape[0] < 2:
        print('No Image data blocks found.')
        return None

    data_block_lower = xor_block(outarr[:, :black_cols.min(), :])
    data_block_upper = xor_block(outarr[:, black_cols.max()+1:, :])

    data_block = (data_block_upper << 4) | data_block_lower
    data = zlib.decompress(data_block.flatten().tobytes())
    return json.loads(data, cls=EmbeddingDecoder)
131
132
def _fit_font_size(draw, text, font, fontsize, max_width, max_size=72):
    """Scale *fontsize* so *text*, measured with *font*, spans about *max_width* pixels.

    The result is capped at *max_size* and floored at 1. Text that measures zero width
    (e.g. an empty string) imposes no constraint, so *max_size* is returned instead of
    dividing by zero.
    """
    _, _, right, _ = draw.textbbox((0, 0), text, font=font)
    if right <= 0:
        return max_size
    return max(1, min(int(fontsize * (max_width / right)), max_size))


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    """Return a copy of *srcimage* with a title and three footer captions drawn on it.

    A black vertical gradient is composited over the image first. Its opacity is lowest
    mid-image and rises toward the top and bottom edges, keeping the white text legible.
    Text placement:
    - The title goes at the top-left, sized to span about 0.75 * width - 4 * padding.
    - The three footers share the largest size at which each fits in about a third of
      the width. They are anchored bottom-left, bottom-centre and bottom-right.
    Font sizes are capped at 72. Fonts come from ``modules.images.get_font``.

    Args:
        srcimage: PIL image; it is copied, not modified.
        title, footerLeft, footerMid, footerRight: caption strings (may be empty).
        textfont: deprecated and ignored; a truthy value emits a DeprecationWarning.

    Returns:
        A new RGBA PIL image.
    """
    from modules.images import get_font
    if textfont:
        warnings.warn(
            'passing in a textfont to caption_image_overlay is deprecated and does nothing',
            DeprecationWarning,
            stacklevel=2,
        )
    from math import cos

    image = srcimage.copy()
    fontsize = 32
    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
    for y in range(image.size[1]):
        mag = 1-cos(y/image.size[1]*factor)
        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

    draw = ImageDraw.Draw(image)
    padding = 10
    text_fill = (255, 255, 255, 230)

    # Measure the title at the base size, then rescale to the target width.
    font = get_font(fontsize)
    fontsize = _fit_font_size(draw, title, font, fontsize, (image.size[0]*0.75)-(padding*4))
    font = get_font(fontsize)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=text_fill)

    # Footers are measured with the title font; the smallest fitting size is used for all.
    footer_width = (image.size[0]/3)-(padding)
    footer_size = min(
        _fit_font_size(draw, text, font, fontsize, footer_width)
        for text in (footerLeft, footerMid, footerRight)
    )
    font = get_font(footer_size)

    bottom = image.size[1]-padding
    draw.text((padding, bottom), footerLeft, anchor='ls', font=font, fill=text_fill)
    draw.text((image.size[0]/2, bottom), footerMid, anchor='ms', font=font, fill=text_fill)
    draw.text((image.size[0]-padding, bottom), footerRight, anchor='rs', font=font, fill=text_fill)

    return image
178
179
if __name__ == '__main__':
    # Self-test; expects test_embedding.png in the current working directory.

    # 1. A stored embedding image must decode from both the pixel panels and the
    #    'sd-ti-embedding' PNG text chunk.
    sample_png = Image.open('test_embedding.png')
    data = extract_image_data_embed(sample_png)
    assert data is not None

    data = embedding_from_b64(sample_png.text['sd-ti-embedding'])
    assert data is not None

    # 2. Round trip: caption a blank image, embed a random tensor, extract it again.
    blank = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    captioned = caption_image_overlay(blank, 'title', 'footerLeft', 'footerMid', 'footerRight')

    payload = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}

    first_pass = insert_image_data_embed(captioned, payload)
    recovered = extract_image_data_embed(first_pass)
    assert str(recovered) == str(payload)

    # 3. Re-embedding the recovered data must reproduce the identical image.
    second_pass = insert_image_data_embed(captioned, recovered)
    assert first_pass == second_pass

    # 4. xor_block scrambles image data with lcg(), so its stream is pinned to known values.
    rng = lcg()
    first_hundred = np.array([next(rng) for _ in range(100)]).astype(np.uint8).tolist()

    expected_first_hundred = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
                              95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
                              160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
                              38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
                              30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
                              41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
                              66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
                              204, 86, 73, 222, 44, 198, 118, 240, 97]

    assert first_hundred == expected_first_hundred

    # The next 100,000 values of the same generator must sum to a fixed checksum.
    following_draws = np.array([next(rng) for _ in range(100000)]).astype(np.uint8).tolist()
    checksum = sum(following_draws)

    assert 12731374 == checksum
221