google-research
145 lines · 4.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: disable=invalid-name
17"""wrappers for loss functions."""
18import lpips19import torch20from torch import nn21import torch.nn.functional as F22
23
def adaptive_downsample256(img, mode='bilinear'):
  """Clamp an image batch to [-1, 1] and resize to 256x256 if larger.

  Args:
    img: image batch of shape BCHW, values expected in [-1, 1].
    mode: interpolation mode forwarded to F.interpolate.

  Returns:
    The clamped batch, resized to 256x256 when its last dim exceeds 256,
    otherwise unchanged in size.
  """
  clamped = img.clamp(-1, 1)
  # Small images pass through untouched; only oversized ones get resampled.
  if clamped.shape[-1] <= 256:
    return clamped
  return F.interpolate(clamped, size=(256, 256), mode=mode)
31
class LPIPS_Loss(nn.Module):
  """Thin wrapper around the LPIPS perceptual distance.

  NOTE: `model` and `use_gpu` are accepted for interface compatibility
  but are not used by this wrapper.
  """

  def __init__(self, model='net-lin', net='vgg', use_gpu=True, spatial=False):
    super().__init__()
    # Frozen (eval-mode) LPIPS network; `net` selects the backbone.
    self.model = lpips.LPIPS(net=net, spatial=spatial).eval()

  def forward(self, pred, ref):
    """Returns the per-sample LPIPS distance with spatial dims squeezed."""
    distance = self.model.forward(pred, ref)
    # With spatial=False LPIPS returns a B x 1 x 1 x 1 tensor.
    assert distance.shape[2] == 1 and distance.shape[3] == 1
    return distance[:, :, 0, 0]
44
def check_loss_input(im0, im1, w):
  """Sanity-checks shapes: im0 = output, im1 = target, w = optional mask."""
  out_spatial = list(im0.shape[2:])
  # Spatial (H, W) extents must agree exactly across all inputs.
  assert out_spatial == list(im1.shape[2:]), 'spatial dim mismatch'
  if w is not None:
    assert out_spatial == list(w.shape[2:]), 'spatial dim mismatch'
  # Batch sizes must match unless the target is broadcastable (batch 1).
  if im1.shape[0] != 1:
    assert im0.shape[0] == im1.shape[0]
  # Same broadcast rule applies to the mask.
  if w is not None and w.shape[0] != 1:
    assert im0.shape[0] == w.shape[0]
58
# masked lpips
class Masked_LPIPS_Loss(nn.Module):
  """LPIPS distance with optional per-pixel spatial weighting."""

  def __init__(self, net='vgg', device='cuda', precision='float'):
    super().__init__()
    # spatial=True makes LPIPS emit a per-pixel distance map instead of
    # a single scalar per image.
    self.lpips = lpips.LPIPS(net=net, spatial=True).eval().to(device)
    # Cast the network to the requested numeric precision.
    casts = {
        'half': self.lpips.half,
        'float': self.lpips.float,
        'double': self.lpips.double,
    }
    if precision in casts:
      casts[precision]()

  def forward(self, im0, im1, w=None):
    """ims have dimension BCHW while mask is B1HW."""
    check_loss_input(im0, im1, w)
    # Per-pixel LPIPS map; lpips sums each spatial map internally.
    distance_map = self.lpips(im0, im1)
    if w is None:
      return distance_map
    weighted_sum = torch.sum(distance_map * w, [1, 2, 3])
    weight_total = torch.sum(w, [1, 2, 3])
    return weighted_sum / weight_total

  def __call__(self, im0, im1, w=None):
    # Deliberately bypasses nn.Module.__call__ (and its hooks) and
    # dispatches straight to forward, as in the original implementation.
    return self.forward(im0, im1, w)
89
class Masked_L1_Loss(nn.Module):
  """L1 loss with an optional spatial mask.

  With a mask, returns a length-B vector of mask-weighted mean absolute
  errors; without one, returns the unreduced element-wise BCHW loss map
  (original behavior, preserved for callers).
  """

  def __init__(self):
    super(Masked_L1_Loss, self).__init__()
    # reduction='none' keeps per-element losses so forward can weight them.
    self.loss = nn.L1Loss(reduction='none')

  def forward(self, pred, ref, w=None):
    """Computes the (optionally masked) L1 loss.

    Args:
      pred: predicted images, BCHW.
      ref: reference images, BCHW.
      w: optional mask, B1HW; broadcast over the channel dimension.

    Returns:
      A B-vector of masked mean losses if w is given, otherwise the
      element-wise BCHW loss. NOTE: an all-zero mask yields NaN (0/0).
    """
    check_loss_input(pred, ref, w)
    # Fail fast on channel mismatch before computing the loss map.
    assert pred.shape[1] == ref.shape[1]
    loss = self.loss(pred, ref)
    if w is not None:
      # Broadcasting w (B1HW) across channels is numerically identical to
      # the former w.repeat(1, C, 1, 1) but avoids materializing C copies
      # of the mask.
      numerator = torch.sum(loss * w, [1, 2, 3])
      # The repeated mask summed to C * sum(w); scale the denominator to
      # keep the normalization identical.
      denominator = torch.sum(w, [1, 2, 3]) * pred.shape[1]
      loss = numerator / denominator
    return loss
110
class L1_Loss(nn.Module):
  """Per-sample mean absolute error, returned as an N x 1 tensor."""

  def __init__(self):
    super(L1_Loss, self).__init__()
    # Keep per-element losses; the reduction happens in forward.
    self.loss = nn.L1Loss(reduction='none')

  def forward(self, pred, ref):
    """ims have dimension BCHW."""
    # output = N x 1
    per_element = self.loss(pred, ref)
    assert pred.shape[1] == ref.shape[1]
    # Average over C, H, W, then unsqueeze into a column vector.
    return per_element.mean(dim=[1, 2, 3]).unsqueeze(1)
126
class Masked_MSE_Loss(nn.Module):
  """MSE loss with an optional spatial mask.

  With a mask, returns a length-B vector of mask-weighted mean squared
  errors; without one, returns the unreduced element-wise BCHW loss map
  (original behavior, preserved for callers).
  """

  def __init__(self):
    super(Masked_MSE_Loss, self).__init__()
    # reduction='none' keeps per-element losses so forward can weight them.
    self.loss = nn.MSELoss(reduction='none')

  def forward(self, pred, ref, w=None):
    """Computes the (optionally masked) MSE loss.

    Args:
      pred: predicted images, BCHW.
      ref: reference images, BCHW.
      w: optional mask, B1HW; broadcast over the channel dimension.

    Returns:
      A B-vector of masked mean losses if w is given, otherwise the
      element-wise BCHW loss. NOTE: an all-zero mask yields NaN (0/0).
    """
    check_loss_input(pred, ref, w)
    # Fail fast on channel mismatch before computing the loss map.
    assert pred.shape[1] == ref.shape[1]
    loss = self.loss(pred, ref)
    if w is not None:
      # Broadcasting w (B1HW) across channels is numerically identical to
      # the former w.repeat(1, C, 1, 1) but avoids materializing C copies
      # of the mask; the denominator is scaled by C to keep the original
      # normalization.
      numerator = torch.sum(loss * w, [1, 2, 3])
      denominator = torch.sum(w, [1, 2, 3]) * pred.shape[1]
      loss = numerator / denominator
    return loss