BasicSR

Форк
0
/
loss_util.py 
145 строк · 4.7 Кб
1
import functools
2
import torch
3
from torch.nn import functional as F
4

5

6
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    The reduction mode is matched by name here instead of going through
    PyTorch's private ``F._Reduction`` helper, so the function does not rely
    on non-public PyTorch internals.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'. The deprecated
            alias 'elementwise_mean' is still accepted (with a warning) and
            treated as 'mean'.

    Returns:
        Tensor: Reduced loss tensor. For 'none' the input tensor is returned
            unchanged; for 'mean' and 'sum' the result of ``loss.mean()`` /
            ``loss.sum()`` is returned.

    Raises:
        ValueError: If ``reduction`` is not one of the supported options.
    """
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'elementwise_mean':
        # Kept for backward compatibility with callers that still pass the
        # old PyTorch spelling of 'mean'.
        import warnings
        warnings.warn("reduction='elementwise_mean' is deprecated. Please use reduction='mean' instead.")
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f'{reduction} is not a valid value for reduction')
24

25

26
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Must have the same number of
            dimensions as ``loss``. When ``loss`` has a dimension 1, the
            weight's size there must be 1 or equal to ``loss.size(1)``.
            Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With a weight and 'mean' reduction, the weighted
            loss sum is divided by the sum of the weights (after broadcasting
            over dimension 1), i.e. the mean is taken over the weight region.

    Raises:
        ValueError: If ``reduction`` is not a supported option.
    """
    # Without a weight this is a plain reduction.
    if weight is None:
        return reduce_loss(loss, reduction)

    assert weight.dim() == loss.dim(), 'weight and loss must have the same number of dimensions'
    # Tensors with fewer than two dimensions have no dimension 1 to inspect;
    # indexing size(1) on them would raise IndexError.
    has_dim1 = loss.dim() > 1
    if has_dim1:
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1), (
            'weight.size(1) must be 1 or equal to loss.size(1)')

    # Previously an unknown reduction combined with a weight silently returned
    # the unreduced tensor; reject it explicitly instead.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f'{reduction} is not a valid value for reduction')

    loss = loss * weight

    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()

    # reduction == 'mean': compute the mean over the weight region
    if not has_dim1:
        # Sum the weight as it was actually applied (after broadcasting).
        normalizer = weight.expand_as(loss).sum()
    elif weight.size(1) > 1:
        normalizer = weight.sum()
    else:
        # Weight was broadcast along dimension 1, so each weight entry
        # counts once per element of that dimension.
        normalizer = weight.sum() * loss.size(1)
    return loss.sum() / normalizer
56

57

58
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    Args:
        loss_func (callable): Function returning the element-wise loss for
            `(pred, target, **kwargs)`.

    Returns:
        callable: Wrapper that computes the element-wise loss and then passes
            it, together with `weight` and `reduction`, to
            `weight_reduce_loss`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss, then apply the weight and the reduction
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
97

98

99
def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Each output value is the unbiased variance of the `ksize` x `ksize`
    window of `residual` around that pixel, taken over dimensions 2 and 3
    after reflect padding.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
            Must be 4-D, since windows are unfolded along dimensions 2 and 3.
        ksize (Int): size of the local window. Reflect padding requires the
            padding on each side to be smaller than the padded dimension.

    Returns:
        Tensor: weight for each pixel to be discriminated as an artifact pixel.
            Has the same shape as `residual`.
    """
    # For odd ksize both sides get (ksize - 1) // 2, as before. For even ksize
    # the trailing side gets one extra pixel so the output keeps the input's
    # spatial size instead of shrinking by one.
    pad_before = (ksize - 1) // 2
    pad_after = ksize // 2
    residual_pad = F.pad(residual, pad=[pad_before, pad_after, pad_before, pad_after], mode='reflect')

    # Two unfolds append a (ksize, ksize) window as the last two dimensions.
    unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    # Reducing over the window dimensions leaves one variance per pixel.
    pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True)

    return pixel_level_weight
119

120

121
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)

    Args:
        img_gt (Tensor): ground truth images.
        img_output (Tensor): output images given by the optimizing model.
        img_ema (Tensor): output images given by the ema model.
        ksize (Int): size of the local window.

    Returns:
        overall_weight: weight for each pixel to be discriminated as an artifact pixel
        (calculated based on both local and global observations). Pixels where
        the optimizing model's residual is smaller than the ema model's
        residual are set to 0.
    """
    # Per-pixel absolute error summed over dimension 1 (kept as size 1).
    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    # Neither consumer below modifies its input, so residual_sr is passed
    # directly instead of being cloned.
    # Global term: variance over the last three dimensions, raised to 1/5.
    patch_level_weight = torch.var(residual_sr, dim=(-1, -2, -3), keepdim=True)**(1 / 5)
    # Local term: variance inside each ksize x ksize window.
    pixel_level_weight = get_local_weights(residual_sr, ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    # Zero the weight wherever the optimizing model's residual is below the
    # ema model's residual.
    overall_weight[residual_sr < residual_ema] = 0

    return overall_weight
146

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.