BasicSR

basic_loss.py · 253 lines · 9.0 KB
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss

_reduction_modes = ['none', 'mean', 'sum']
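
# The @LOSS_REGISTRY.register() decorators below register each loss class by
# name, so BasicSR can typically build it straight from a config entry such as
# `pixel_opt: {type: L1Loss, loss_weight: 1.0, reduction: mean}`; the lookup
# itself lives in basicsr.utils.registry.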


@weighted_loss
def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    return F.mse_loss(pred, target, reduction='none')


@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target)**2 + eps)
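
# The @weighted_loss decorator (from .loss_util) wraps each elementwise loss
# so the decorated function is called as, e.g., l1_loss(pred, target, weight,
# reduction='mean'): the raw per-element loss is multiplied by the optional
# weight map and then reduced. Charbonnier differs from plain L1 only near
# zero, where sqrt(x**2 + eps) stays smooth and differentiable instead of
# having L1's kink at x = 0; for |x| much larger than sqrt(eps) it is
# effectively |x|.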


@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
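
# Usage sketch (illustrative): the pixel losses in this file share this
# calling convention, so a training loop would typically do something like
#
#     criterion = L1Loss(loss_weight=1.0, reduction='mean')
#     loss = criterion(pred, gt)                # scalar tensor
#     loss = criterion(pred, gt, weight=mask)   # down-weight invalid pixels
#
# MSELoss and CharbonnierLoss below expose the identical interface.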


@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)


@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero. Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
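
# Note on eps: with the default eps=1e-12 the region where the loss is
# noticeably smoother than |x| extends only to about sqrt(eps) = 1e-6, so this
# variant behaves almost exactly like L1 while keeping a finite gradient at
# zero. Some formulations use a much larger value (e.g. eps=1e-6, i.e.
# sqrt(eps) = 1e-3) for a visibly smoother penalty; it is a tunable
# constructor argument here.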


@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted TV loss.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss
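
# What WeightedTVLoss computes: the L1 difference between each pixel and its
# bottom (y_diff) and right (x_diff) neighbour, i.e. an anisotropic
# total-variation penalty built from two shifted views of `pred`. A quick
# sanity check (illustrative): a constant image has zero TV loss, while
# high-frequency noise raises it, which is why TV is used as a smoothness
# regularizer.
#
#     tv = WeightedTVLoss(loss_weight=1.0)
#     tv(torch.ones(1, 3, 8, 8))               # -> tensor(0.)
#     tv(torch.rand(1, 3, 8, 8))               # -> strictly positive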


@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, normalize images from range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple(Tensor | None, Tensor | None): Perceptual loss and style
                loss; each entry is None when its weight is 0.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
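
# Usage sketch (illustrative): PerceptualLoss returns a (percep_loss,
# style_loss) tuple, and either element is None when its weight is 0, so the
# caller has to guard before adding it to the total loss:
#
#     percep = PerceptualLoss(layer_weights={'conv5_4': 1.0},
#                             perceptual_weight=1.0, style_weight=0.)
#     l_percep, l_style = percep(output, gt)
#     total = l_percep if l_percep is not None else 0
#     if l_style is not None:
#         total = total + l_style
#
# _gram_mat returns an (n, c, c) matrix of channel-wise feature correlations,
# normalized by c * h * w; matching Gram matrices rather than raw features is
# what makes the style term insensitive to the spatial layout of the image.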
