google-research

Форк
0
145 строк · 4.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# pylint: disable=invalid-name
17
"""wrappers for loss functions."""
18
import lpips
19
import torch
20
from torch import nn
21
import torch.nn.functional as F
22

23

24
def adaptive_downsample256(img, mode='bilinear'):
25
  img = img.clamp(-1, 1)
26
  if img.shape[-1] > 256:
27
    return F.interpolate(img, size=(256, 256), mode=mode)
28
  else:
29
    return img
30

31

32
class LPIPS_Loss(nn.Module):
33
  """Wrapper for LPIPS loss."""
34

35
  def __init__(self, model='net-lin', net='vgg', use_gpu=True, spatial=False):
36
    super(LPIPS_Loss, self).__init__()
37
    self.model = lpips.LPIPS(net=net, spatial=spatial).eval()
38

39
  def forward(self, pred, ref):
40
    dist = self.model.forward(pred, ref)
41
    assert dist.shape[2] == 1 and dist.shape[3] == 1
42
    return dist[:, :, 0, 0]  # squeeze spatial dimensions
43

44

45
def check_loss_input(im0, im1, w):
46
  """im0 is out and im1 is target and w is mask."""
47
  assert list(im0.size())[2:] == list(im1.size())[2:], 'spatial dim mismatch'
48
  if w is not None:
49
    assert list(im0.size())[2:] == list(w.size())[2:], 'spatial dim mismatch'
50

51
  if im1.size(0) != 1:
52
    assert im0.size(0) == im1.size(0)
53

54
  if w is not None and w.size(0) != 1:
55
    assert im0.size(0) == w.size(0)
56
  return
57

58

59
# masked lpips
60
class Masked_LPIPS_Loss(nn.Module):
61
  """LPIPS loss with spatial weighting."""
62

63
  def __init__(self, net='vgg', device='cuda', precision='float'):
64
    super(Masked_LPIPS_Loss, self).__init__()
65
    self.lpips = lpips.LPIPS(net=net, spatial=True).eval()
66
    self.lpips = self.lpips.to(device)
67
    if precision == 'half':
68
      self.lpips.half()
69
    elif precision == 'float':
70
      self.lpips.float()
71
    elif precision == 'double':
72
      self.lpips.double()
73
    return
74

75
  def forward(self, im0, im1, w=None):
76
    """ims have dimension BCHW while mask is B1HW."""
77
    check_loss_input(im0, im1, w)
78
    # lpips takes the sum of each spatial map
79
    loss = self.lpips(im0, im1)
80
    if w is not None:
81
      n = torch.sum(loss * w, [1, 2, 3])
82
      d = torch.sum(w, [1, 2, 3])
83
      loss = n / d
84
    return loss
85

86
  def __call__(self, im0, im1, w=None):
87
    return self.forward(im0, im1, w)
88

89

90
class Masked_L1_Loss(nn.Module):
91
  """L1 loss with mask."""
92

93
  def __init__(self):
94
    super(Masked_L1_Loss, self).__init__()
95
    self.loss = nn.L1Loss(reduction='none')
96

97
  def forward(self, pred, ref, w=None):
98
    """ims have dimension BCHW while mask is B1HW."""
99
    check_loss_input(pred, ref, w)
100
    loss = self.loss(pred, ref)
101
    assert pred.shape[1] == ref.shape[1]
102
    channels = pred.shape[1]
103
    if w is not None:
104
      w = w.repeat(1, channels, 1, 1)  # repeat on channel wise dim
105
      n = torch.sum(loss * w, [1, 2, 3])
106
      d = torch.sum(w, [1, 2, 3])
107
      loss = n / d
108
    return loss
109

110

111
class L1_Loss(nn.Module):
112
  """Standard L1 loss, for each item in batch."""
113

114
  def __init__(self):
115
    super(L1_Loss, self).__init__()
116
    self.loss = nn.L1Loss(reduction='none')
117

118
  def forward(self, pred, ref):
119
    """ims have dimension BCHW."""
120
    # output = N x 1
121
    loss = self.loss(pred, ref)
122
    assert pred.shape[1] == ref.shape[1]
123
    loss = torch.mean(loss, dim=[1, 2, 3])[:, None]
124
    return loss
125

126

127
class Masked_MSE_Loss(nn.Module):
128
  """MSE loss with masking."""
129

130
  def __init__(self):
131
    super(Masked_MSE_Loss, self).__init__()
132
    self.loss = nn.MSELoss(reduction='none')
133

134
  def forward(self, pred, ref, w=None):
135
    """ims have dimension BCHW while mask is B1HW."""
136
    check_loss_input(pred, ref, w)
137
    loss = self.loss(pred, ref)
138
    assert pred.shape[1] == ref.shape[1]
139
    channels = pred.shape[1]
140
    if w is not None:
141
      w = w.repeat(1, channels, 1, 1)  # repeat on channel wise dim
142
      n = torch.sum(loss * w, [1, 2, 3])
143
      d = torch.sum(w, [1, 2, 3])
144
      loss = n / d
145
    return loss
146

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.