import functools

import torch
from torch.nn import functional as F
def reduce_loss(loss, reduction):
7
"""Reduce loss as specified.
10
loss (Tensor): Elementwise loss tensor.
11
reduction (str): Options are 'none', 'mean' and 'sum'.
14
Tensor: Reduced loss tensor.
16
reduction_enum = F._Reduction.get_enum(reduction)
18
if reduction_enum == 0:
20
elif reduction_enum == 1:
26
def weight_reduce_loss(loss, weight=None, reduction='mean'):
27
"""Apply element-wise weight and reduce loss.
30
loss (Tensor): Element-wise loss.
31
weight (Tensor): Element-wise weights. Default: None.
32
reduction (str): Same as built-in losses of PyTorch. Options are
33
'none', 'mean' and 'sum'. Default: 'mean'.
39
if weight is not None:
40
assert weight.dim() == loss.dim()
41
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
45
if weight is None or reduction == 'sum':
46
loss = reduce_loss(loss, reduction)
48
elif reduction == 'mean':
49
if weight.size(1) > 1:
52
weight = weight.sum() * loss.size(1)
53
loss = loss.sum() / weight
58
def weighted_loss(loss_func):
59
"""Create a weighted version of a given loss function.
61
To use this decorator, the loss function must have the signature like
62
`loss_func(pred, target, **kwargs)`. The function only needs to compute
63
element-wise loss without any reduction. This decorator will add weight
64
and reduction arguments to the function. The decorated function will have
65
the signature like `loss_func(pred, target, weight=None, reduction='mean',
72
>>> def l1_loss(pred, target):
73
>>> return (pred - target).abs()
75
>>> pred = torch.Tensor([0, 2, 3])
76
>>> target = torch.Tensor([1, 1, 1])
77
>>> weight = torch.Tensor([1, 0, 1])
79
>>> l1_loss(pred, target)
81
>>> l1_loss(pred, target, weight)
83
>>> l1_loss(pred, target, reduction='none')
85
>>> l1_loss(pred, target, weight, reduction='sum')
89
@functools.wraps(loss_func)
90
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
92
loss = loss_func(pred, target, **kwargs)
93
loss = weight_reduce_loss(loss, weight, reduction)
99
def get_local_weights(residual, ksize):
100
"""Get local weights for generating the artifact map of LDL.
102
It is only called by the `get_refined_artifact_map` function.
105
residual (Tensor): Residual between predicted and ground truth images.
106
ksize (Int): size of the local window.
109
Tensor: weight for each pixel to be discriminated as an artifact pixel
112
pad = (ksize - 1) // 2
113
residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
115
unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
116
pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
118
return pixel_level_weight
121
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
122
"""Calculate the artifact map of LDL
123
(Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
126
img_gt (Tensor): ground truth images.
127
img_output (Tensor): output images given by the optimizing model.
128
img_ema (Tensor): output images given by the ema model.
129
ksize (Int): size of the local window.
132
overall_weight: weight for each pixel to be discriminated as an artifact pixel
133
(calculated based on both local and global observations).
136
residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
137
residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
139
patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
140
pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
141
overall_weight = patch_level_weight * pixel_level_weight
143
overall_weight[residual_sr < residual_ema] = 0
145
return overall_weight