BasicSR

Форк
0
/
img_process_util.py 
83 строки · 2.5 Кб
1
import cv2
2
import numpy as np
3
import torch
4
from torch.nn import functional as F
5

6

7
def filter2D(img, kernel):
8
    """PyTorch version of cv2.filter2D
9

10
    Args:
11
        img (Tensor): (b, c, h, w)
12
        kernel (Tensor): (b, k, k)
13
    """
14
    k = kernel.size(-1)
15
    b, c, h, w = img.size()
16
    if k % 2 == 1:
17
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
18
    else:
19
        raise ValueError('Wrong kernel size')
20

21
    ph, pw = img.size()[-2:]
22

23
    if kernel.size(0) == 1:
24
        # apply the same kernel to all batch images
25
        img = img.view(b * c, 1, ph, pw)
26
        kernel = kernel.view(1, 1, k, k)
27
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
28
    else:
29
        img = img.view(1, b * c, ph, pw)
30
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
31
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
32

33

34
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
35
    """USM sharpening.
36

37
    Input image: I; Blurry image: B.
38
    1. sharp = I + weight * (I - B)
39
    2. Mask = 1 if abs(I - B) > threshold, else: 0
40
    3. Blur mask:
41
    4. Out = Mask * sharp + (1 - Mask) * I
42

43

44
    Args:
45
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
46
        weight (float): Sharp weight. Default: 1.
47
        radius (float): Kernel size of Gaussian blur. Default: 50.
48
        threshold (int):
49
    """
50
    if radius % 2 == 0:
51
        radius += 1
52
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
53
    residual = img - blur
54
    mask = np.abs(residual) * 255 > threshold
55
    mask = mask.astype('float32')
56
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
57

58
    sharp = img + weight * residual
59
    sharp = np.clip(sharp, 0, 1)
60
    return soft_mask * sharp + (1 - soft_mask) * img
61

62

63
class USMSharp(torch.nn.Module):
64

65
    def __init__(self, radius=50, sigma=0):
66
        super(USMSharp, self).__init__()
67
        if radius % 2 == 0:
68
            radius += 1
69
        self.radius = radius
70
        kernel = cv2.getGaussianKernel(radius, sigma)
71
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
72
        self.register_buffer('kernel', kernel)
73

74
    def forward(self, img, weight=0.5, threshold=10):
75
        blur = filter2D(img, self.kernel)
76
        residual = img - blur
77

78
        mask = torch.abs(residual) * 255 > threshold
79
        mask = mask.float()
80
        soft_mask = filter2D(mask, self.kernel)
81
        sharp = img + weight * residual
82
        sharp = torch.clip(sharp, 0, 1)
83
        return soft_mask * sharp + (1 - soft_mask) * img
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.