2
Modified from https://github.com/mlomnitz/DiffJPEG
4
For images not divisible by 8
5
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
11
from torch.nn import functional as F
13
# ------------------------ utils ------------------------#
15
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
16
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
17
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
19
y_table = nn.Parameter(torch.from_numpy(y_table))
20
c_table = np.empty((8, 8), dtype=np.float32)
22
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
23
c_table = nn.Parameter(torch.from_numpy(c_table))
27
""" Differentiable rounding function
29
return torch.round(x) + (x - torch.round(x))**3
32
def quality_to_factor(quality):
33
""" Calculate factor corresponding to quality
36
quality(float): Quality for jpeg compression.
39
float: Compression factor.
42
quality = 5000. / quality
44
quality = 200. - quality * 2
48
# ------------------------ compression ------------------------#
49
class RGB2YCbCrJpeg(nn.Module):
50
""" Converts RGB image to YCbCr
54
super(RGB2YCbCrJpeg, self).__init__()
55
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
57
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
58
self.matrix = nn.Parameter(torch.from_numpy(matrix))
60
def forward(self, image):
63
image(Tensor): batch x 3 x height x width
66
Tensor: batch x height x width x 3
68
image = image.permute(0, 2, 3, 1)
69
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
70
return result.view(image.shape)
73
class ChromaSubsampling(nn.Module):
74
""" Chroma subsampling on CbCr channels
78
super(ChromaSubsampling, self).__init__()
80
def forward(self, image):
83
image(tensor): batch x height x width x 3
86
y(tensor): batch x height x width
87
cb(tensor): batch x height/2 x width/2
88
cr(tensor): batch x height/2 x width/2
90
image_2 = image.permute(0, 3, 1, 2).clone()
91
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
92
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
93
cb = cb.permute(0, 2, 3, 1)
94
cr = cr.permute(0, 2, 3, 1)
95
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
98
class BlockSplitting(nn.Module):
99
""" Splitting image into patches
103
super(BlockSplitting, self).__init__()
106
def forward(self, image):
109
image(tensor): batch x height x width
112
Tensor: batch x h*w/64 x h x w
114
height, _ = image.shape[1:3]
115
batch_size = image.shape[0]
116
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
117
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
118
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
121
class DCT8x8(nn.Module):
122
""" Discrete Cosine Transformation
126
super(DCT8x8, self).__init__()
127
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
128
for x, y, u, v in itertools.product(range(8), repeat=4):
129
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
130
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
131
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
132
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
134
def forward(self, image):
137
image(tensor): batch x height x width
140
Tensor: batch x height x width
143
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
144
result.view(image.shape)
148
class YQuantize(nn.Module):
149
""" JPEG Quantization for Y channel
152
rounding(function): rounding function to use
155
def __init__(self, rounding):
156
super(YQuantize, self).__init__()
157
self.rounding = rounding
158
self.y_table = y_table
160
def forward(self, image, factor=1):
163
image(tensor): batch x height x width
166
Tensor: batch x height x width
168
if isinstance(factor, (int, float)):
169
image = image.float() / (self.y_table * factor)
172
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
173
image = image.float() / table
174
image = self.rounding(image)
178
class CQuantize(nn.Module):
179
""" JPEG Quantization for CbCr channels
182
rounding(function): rounding function to use
185
def __init__(self, rounding):
186
super(CQuantize, self).__init__()
187
self.rounding = rounding
188
self.c_table = c_table
190
def forward(self, image, factor=1):
193
image(tensor): batch x height x width
196
Tensor: batch x height x width
198
if isinstance(factor, (int, float)):
199
image = image.float() / (self.c_table * factor)
202
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
203
image = image.float() / table
204
image = self.rounding(image)
208
class CompressJpeg(nn.Module):
209
"""Full JPEG compression algorithm
212
rounding(function): rounding function to use
215
def __init__(self, rounding=torch.round):
216
super(CompressJpeg, self).__init__()
217
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
218
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
219
self.c_quantize = CQuantize(rounding=rounding)
220
self.y_quantize = YQuantize(rounding=rounding)
222
def forward(self, image, factor=1):
225
image(tensor): batch x 3 x height x width
228
dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
230
y, cb, cr = self.l1(image * 255)
231
components = {'y': y, 'cb': cb, 'cr': cr}
232
for k in components.keys():
233
comp = self.l2(components[k])
234
if k in ('cb', 'cr'):
235
comp = self.c_quantize(comp, factor=factor)
237
comp = self.y_quantize(comp, factor=factor)
241
return components['y'], components['cb'], components['cr']
244
# ------------------------ decompression ------------------------#
247
class YDequantize(nn.Module):
248
"""Dequantize Y channel
252
super(YDequantize, self).__init__()
253
self.y_table = y_table
255
def forward(self, image, factor=1):
258
image(tensor): batch x height x width
261
Tensor: batch x height x width
263
if isinstance(factor, (int, float)):
264
out = image * (self.y_table * factor)
267
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
272
class CDequantize(nn.Module):
273
"""Dequantize CbCr channel
277
super(CDequantize, self).__init__()
278
self.c_table = c_table
280
def forward(self, image, factor=1):
283
image(tensor): batch x height x width
286
Tensor: batch x height x width
288
if isinstance(factor, (int, float)):
289
out = image * (self.c_table * factor)
292
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
297
class iDCT8x8(nn.Module):
298
"""Inverse discrete Cosine Transformation
302
super(iDCT8x8, self).__init__()
303
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
304
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
305
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
306
for x, y, u, v in itertools.product(range(8), repeat=4):
307
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
308
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
310
def forward(self, image):
313
image(tensor): batch x height x width
316
Tensor: batch x height x width
318
image = image * self.alpha
319
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
320
result.view(image.shape)
324
class BlockMerging(nn.Module):
325
"""Merge patches into image
329
super(BlockMerging, self).__init__()
331
def forward(self, patches, height, width):
334
patches(tensor) batch x height*width/64, height x width
339
Tensor: batch x height x width
342
batch_size = patches.shape[0]
343
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
344
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
345
return image_transposed.contiguous().view(batch_size, height, width)
348
class ChromaUpsampling(nn.Module):
349
"""Upsample chroma layers
353
super(ChromaUpsampling, self).__init__()
355
def forward(self, y, cb, cr):
358
y(tensor): y channel image
359
cb(tensor): cb channel
360
cr(tensor): cr channel
363
Tensor: batch x height x width x 3
367
height, width = x.shape[1:3]
369
x = x.repeat(1, 1, k, k)
370
x = x.view(-1, height * k, width * k)
375
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
378
class YCbCr2RGBJpeg(nn.Module):
379
"""Converts YCbCr image to RGB JPEG
383
super(YCbCr2RGBJpeg, self).__init__()
385
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
386
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
387
self.matrix = nn.Parameter(torch.from_numpy(matrix))
389
def forward(self, image):
392
image(tensor): batch x height x width x 3
395
Tensor: batch x 3 x height x width
397
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
398
return result.view(image.shape).permute(0, 3, 1, 2)
401
class DeCompressJpeg(nn.Module):
402
"""Full JPEG decompression algorithm
405
rounding(function): rounding function to use
408
def __init__(self, rounding=torch.round):
409
super(DeCompressJpeg, self).__init__()
410
self.c_dequantize = CDequantize()
411
self.y_dequantize = YDequantize()
412
self.idct = iDCT8x8()
413
self.merging = BlockMerging()
414
self.chroma = ChromaUpsampling()
415
self.colors = YCbCr2RGBJpeg()
417
def forward(self, y, cb, cr, imgh, imgw, factor=1):
420
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
426
Tensor: batch x 3 x height x width
428
components = {'y': y, 'cb': cb, 'cr': cr}
429
for k in components.keys():
430
if k in ('cb', 'cr'):
431
comp = self.c_dequantize(components[k], factor=factor)
432
height, width = int(imgh / 2), int(imgw / 2)
434
comp = self.y_dequantize(components[k], factor=factor)
435
height, width = imgh, imgw
436
comp = self.idct(comp)
437
components[k] = self.merging(comp, height, width)
439
image = self.chroma(components['y'], components['cb'], components['cr'])
440
image = self.colors(image)
442
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
446
# ------------------------ main DiffJPEG ------------------------ #
449
class DiffJPEG(nn.Module):
450
"""This JPEG algorithm result is slightly different from cv2.
451
DiffJPEG supports batch processing.
454
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
457
def __init__(self, differentiable=True):
458
super(DiffJPEG, self).__init__()
460
rounding = diff_round
462
rounding = torch.round
464
self.compress = CompressJpeg(rounding=rounding)
465
self.decompress = DeCompressJpeg(rounding=rounding)
467
def forward(self, x, quality):
470
x (Tensor): Input image, bchw, rgb, [0, 1]
471
quality(float): Quality factor for jpeg compression scheme.
474
if isinstance(factor, (int, float)):
475
factor = quality_to_factor(factor)
477
for i in range(factor.size(0)):
478
factor[i] = quality_to_factor(factor[i])
486
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
488
y, cb, cr = self.compress(x, factor=factor)
489
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
490
recovered = recovered[:, :, 0:h, 0:w]
494
if __name__ == '__main__':
497
from basicsr.utils import img2tensor, tensor2img
499
img_gt = cv2.imread('test.png') / 255.
501
# -------------- cv2 -------------- #
502
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
503
_, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
504
img_lq = np.float32(cv2.imdecode(encimg, 1))
505
cv2.imwrite('cv2_JPEG_20.png', img_lq)
507
# -------------- DiffJPEG -------------- #
508
jpeger = DiffJPEG(differentiable=False).cuda()
509
img_gt = img2tensor(img_gt)
510
img_gt = torch.stack([img_gt, img_gt]).cuda()
511
quality = img_gt.new_tensor([20, 40])
512
out = jpeger(img_gt, quality=quality)
514
cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
515
cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))