"""
2
Modified from https://github.com/mlomnitz/DiffJPEG
3

4
For images not divisible by 8
5
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
6
"""
7
import itertools
8
import numpy as np
9
import torch
10
import torch.nn as nn
11
from torch.nn import functional as F
12

13
# ------------------------ utils ------------------------#
14
y_table = np.array(
15
    [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
16
     [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
17
     [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
18
    dtype=np.float32).T
19
y_table = nn.Parameter(torch.from_numpy(y_table))
20
c_table = np.empty((8, 8), dtype=np.float32)
21
c_table.fill(99)
22
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
23
c_table = nn.Parameter(torch.from_numpy(c_table))
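# y_table / c_table are the standard JPEG base quantization tables for the luma and chroma
# channels (entries of c_table outside the top-left 4x4 default to 99). They are wrapped in
# nn.Parameter so that, once assigned to a module below, they move with it across devices.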


def diff_round(x):
    """ Differentiable rounding function
    """
    return torch.round(x) + (x - torch.round(x))**3
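# The value stays close to torch.round(x) (the cubic correction is at most 0.125 in magnitude),
# but unlike torch.round the function has a useful gradient, 3 * (x - round(x))**2, so the
# quantization step can sit inside a differentiable pipeline.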


def quality_to_factor(quality):
    """ Calculate factor corresponding to quality

    Args:
        quality(float): Quality for jpeg compression.

    Returns:
        float: Compression factor.
    """
    if quality < 50:
        quality = 5000. / quality
    else:
        quality = 200. - quality * 2
    return quality / 100.
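# This is the usual libjpeg quality-to-scaling mapping: quality 10 -> factor 5.0 (coarser
# quantization), quality 50 -> 1.0 (the base tables as-is), quality 90 -> 0.2 (finer quantization).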


# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
    """ Converts RGB image to YCbCr
    """

    def __init__(self):
        super(RGB2YCbCrJpeg, self).__init__()
        matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
                          dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image(Tensor): batch x 3 x height x width

        Returns:
            Tensor: batch x height x width x 3
        """
        image = image.permute(0, 2, 3, 1)
        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
        return result.view(image.shape)


class ChromaSubsampling(nn.Module):
    """ Chroma subsampling on CbCr channels
    """

    def __init__(self):
        super(ChromaSubsampling, self).__init__()

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            y(tensor): batch x height x width
            cb(tensor): batch x height/2 x width/2
            cr(tensor): batch x height/2 x width/2
        """
        image_2 = image.permute(0, 3, 1, 2).clone()
        cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cb = cb.permute(0, 2, 3, 1)
        cr = cr.permute(0, 2, 3, 1)
        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
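# avg_pool2d with a 2x2 kernel halves the chroma resolution (4:2:0 subsampling): each Cb/Cr
# sample is the mean of a 2x2 neighbourhood, while Y keeps full resolution.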


class BlockSplitting(nn.Module):
    """ Splitting image into patches
    """

    def __init__(self):
        super(BlockSplitting, self).__init__()
        self.k = 8

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x h*w/64 x 8 x 8
        """
        height, _ = image.shape[1:3]
        batch_size = image.shape[0]
        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)


class DCT8x8(nn.Module):
    """ Discrete Cosine Transformation
    """

    def __init__(self):
        super(DCT8x8, self).__init__()
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        result.view(image.shape)
        return result
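# For each 8x8 block this computes the type-II 2-D DCT used by JPEG:
#   F(u, v) = 0.25 * alpha(u) * alpha(v) * sum_{x,y} (f(x, y) - 128)
#             * cos((2x + 1) * u * pi / 16) * cos((2y + 1) * v * pi / 16),
# with alpha(0) = 1/sqrt(2) and alpha(k) = 1 otherwise; the 128 shift centres the pixel
# values around zero before the transform.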


class YQuantize(nn.Module):
    """ JPEG Quantization for Y channel

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding):
        super(YQuantize, self).__init__()
        self.rounding = rounding
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image
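# Quantization is the lossy step: DCT coefficients are divided by the factor-scaled table and
# rounded. `factor` may be a Python scalar or a per-sample tensor of shape (batch,), which is
# broadcast so that each image in the batch gets its own quality.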


class CQuantize(nn.Module):
    """ JPEG Quantization for CbCr channels

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding):
        super(CQuantize, self).__init__()
        self.rounding = rounding
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image


class CompressJpeg(nn.Module):
    """Full JPEG compression algorithm

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(CompressJpeg, self).__init__()
        self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
        self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
        self.c_quantize = CQuantize(rounding=rounding)
        self.y_quantize = YQuantize(rounding=rounding)

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x 3 x height x width

        Returns:
            tuple(Tensor): quantized y, cb, cr coefficients, each of shape
                batch x num_blocks x 8 x 8 (cb and cr have a quarter as many blocks as y).
        """
        y, cb, cr = self.l1(image * 255)
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            comp = self.l2(components[k])
            if k in ('cb', 'cr'):
                comp = self.c_quantize(comp, factor=factor)
            else:
                comp = self.y_quantize(comp, factor=factor)

            components[k] = comp

        return components['y'], components['cb'], components['cr']
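# The quantized coefficients are returned directly; there is no entropy-coding stage, so this
# pair of modules only models the lossy part of JPEG, which is what matters for degradation
# simulation and for gradients.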


# ------------------------ decompression ------------------------#


class YDequantize(nn.Module):
    """Dequantize Y channel
    """

    def __init__(self):
        super(YDequantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class CDequantize(nn.Module):
    """Dequantize CbCr channel
    """

    def __init__(self):
        super(CDequantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class iDCT8x8(nn.Module):
    """Inverse discrete Cosine Transformation
    """

    def __init__(self):
        super(iDCT8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result.view(image.shape)
        return result
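# The inverse 8x8 DCT: coefficients are rescaled by alpha(u) * alpha(v), transformed back with
# the inverse (type-III) cosine basis, and the 128 level shift applied in DCT8x8 is added back.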


class BlockMerging(nn.Module):
    """Merge patches into image
    """

    def __init__(self):
        super(BlockMerging, self).__init__()

    def forward(self, patches, height, width):
        """
        Args:
            patches(tensor): batch x height*width/64 x 8 x 8
            height(int)
            width(int)

        Returns:
            Tensor: batch x height x width
        """
        k = 8
        batch_size = patches.shape[0]
        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, height, width)


class ChromaUpsampling(nn.Module):
    """Upsample chroma layers
    """

    def __init__(self):
        super(ChromaUpsampling, self).__init__()

    def forward(self, y, cb, cr):
        """
        Args:
            y(tensor): y channel image
            cb(tensor): cb channel
            cr(tensor): cr channel

        Returns:
            Tensor: batch x height x width x 3
        """

        def repeat(x, k=2):
            height, width = x.shape[1:3]
            x = x.unsqueeze(-1)
            x = x.repeat(1, 1, k, k)
            x = x.view(-1, height * k, width * k)
            return x

        cb = repeat(cb)
        cr = repeat(cr)
        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
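# Chroma is brought back to full resolution by nearest-neighbour repetition (each Cb/Cr value
# fills a 2x2 block), then stacked with Y into a batch x height x width x 3 tensor.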


class YCbCr2RGBJpeg(nn.Module):
    """Converts YCbCr image to RGB JPEG
    """

    def __init__(self):
        super(YCbCr2RGBJpeg, self).__init__()

        matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            Tensor: batch x 3 x height x width
        """
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.view(image.shape).permute(0, 3, 1, 2)


class DeCompressJpeg(nn.Module):
    """Full JPEG decompression algorithm

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(DeCompressJpeg, self).__init__()
        self.c_dequantize = CDequantize()
        self.y_dequantize = YDequantize()
        self.idct = iDCT8x8()
        self.merging = BlockMerging()
        self.chroma = ChromaUpsampling()
        self.colors = YCbCr2RGBJpeg()

    def forward(self, y, cb, cr, imgh, imgw, factor=1):
        """
        Args:
            y(tensor), cb(tensor), cr(tensor): quantized DCT coefficients, batch x h*w/64 x 8 x 8
            imgh(int): height of the (padded) image
            imgw(int): width of the (padded) image
            factor(float)

        Returns:
            Tensor: batch x 3 x height x width
        """
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            if k in ('cb', 'cr'):
                comp = self.c_dequantize(components[k], factor=factor)
                height, width = int(imgh / 2), int(imgw / 2)
            else:
                comp = self.y_dequantize(components[k], factor=factor)
                height, width = imgh, imgw
            comp = self.idct(comp)
            components[k] = self.merging(comp, height, width)

        image = self.chroma(components['y'], components['cb'], components['cr'])
        image = self.colors(image)

        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
        return image / 255


# ------------------------ main DiffJPEG ------------------------ #


class DiffJPEG(nn.Module):
    """This JPEG algorithm result is slightly different from cv2.
    DiffJPEG supports batch processing.

    Args:
        differentiable(bool): If True, uses the custom differentiable rounding function;
            if False, uses standard torch.round.
    """

    def __init__(self, differentiable=True):
        super(DiffJPEG, self).__init__()
        if differentiable:
            rounding = diff_round
        else:
            rounding = torch.round

        self.compress = CompressJpeg(rounding=rounding)
        self.decompress = DeCompressJpeg(rounding=rounding)

    def forward(self, x, quality):
        """
        Args:
            x (Tensor): Input image, bchw, rgb, [0, 1]
            quality(float): Quality factor for jpeg compression scheme.
        """
        factor = quality
        if isinstance(factor, (int, float)):
            factor = quality_to_factor(factor)
        else:
            for i in range(factor.size(0)):
                factor[i] = quality_to_factor(factor[i])
        h, w = x.size()[-2:]
        h_pad, w_pad = 0, 0
        # Pad to a multiple of 16: chroma is subsampled by 2 and then split into 8x8 blocks,
        # so both dimensions of the padded image must be divisible by 2 * 8 = 16.
        if h % 16 != 0:
            h_pad = 16 - h % 16
        if w % 16 != 0:
            w_pad = 16 - w % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered


if __name__ == '__main__':
    import cv2

    from basicsr.utils import img2tensor, tensor2img

    img_gt = cv2.imread('test.png') / 255.

    # -------------- cv2 -------------- #
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
    _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
    img_lq = np.float32(cv2.imdecode(encimg, 1))
    cv2.imwrite('cv2_JPEG_20.png', img_lq)

    # -------------- DiffJPEG -------------- #
    jpeger = DiffJPEG(differentiable=False).cuda()
    img_gt = img2tensor(img_gt)
    img_gt = torch.stack([img_gt, img_gt]).cuda()
    quality = img_gt.new_tensor([20, 40])
    out = jpeger(img_gt, quality=quality)

    cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
    cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))