from typing import List, Tuple, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
from saicinpainting.training.modules.pix2pixhd import ResnetBlock


class ResNetHead(nn.Module):
    """ResNet-style encoder: a 7x7 input convolution, n_downsampling stride-2
    convolutions, then n_blocks residual blocks."""

    def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
        assert n_blocks >= 0
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      activation]

        mult = 2 ** n_downsampling

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                  conv_kind=conv_kind)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


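# Shape sketch (illustrative, not part of the original file): with the defaults
# ngf=64 and n_downsampling=3, ResNetHead maps [B, input_nc, H, W] to
# [B, ngf * 2**n_downsampling, H // 8, W // 8]:
#   head = ResNetHead(input_nc=4)
#   feats = head(torch.randn(1, 4, 256, 256))  # torch.Size([1, 512, 32, 32])

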
class ResNetTail(nn.Module):
    """ResNet-style decoder: an optional 1x1 input projection, n_blocks residual
    blocks, n_downsampling stride-2 transposed convolutions, and an output head."""

    def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
                 add_in_proj=None):
        assert n_blocks >= 0
        super().__init__()

        mult = 2 ** n_downsampling

        model = []

        # project concatenated inputs (own head features + features of the
        # previous, lower-resolution tail) back to ngf * mult channels
        if add_in_proj is not None:
            model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                  conv_kind=conv_kind)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                         output_padding=1),
                      up_norm_layer(int(ngf * mult / 2)),
                      up_activation]
        self.model = nn.Sequential(*model)

        out_layers = []
        for _ in range(out_extra_layers_n):
            out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
                           up_norm_layer(ngf),
                           up_activation]
        out_layers += [nn.ReflectionPad2d(3),
                       nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))

        self.out_proj = nn.Sequential(*out_layers)

    def forward(self, input, return_last_act=False):
        features = self.model(input)
        out = self.out_proj(features)
        if return_last_act:
            return out, features
        else:
            return out


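# Shape sketch (illustrative, not part of the original file): ResNetTail inverts
# ResNetHead spatially; with return_last_act=True it also returns the ngf-channel
# feature map that feeds out_proj:
#   tail = ResNetTail(output_nc=3)
#   out, feats = tail(torch.randn(1, 512, 32, 32), return_last_act=True)
#   # out: torch.Size([1, 3, 256, 256]), feats: torch.Size([1, 64, 256, 256])

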
class MultiscaleResNet(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
                 out_cumulative=False, return_only_hr=False):
        super().__init__()

        self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
                                               n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
                                               conv_kind=conv_kind, activation=activation)
                                    for i in range(n_scales)])
        # every tail except the lowest-resolution one receives its head features
        # concatenated with the ngf-channel features of the previous tail
        tail_in_feats = ngf * (2 ** n_downsampling) + ngf
        self.tails = nn.ModuleList([ResNetTail(output_nc,
                                               ngf=ngf, n_downsampling=n_downsampling,
                                               n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
                                               conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
                                               up_activation=up_activation, add_out_act=add_out_act,
                                               out_extra_layers_n=out_extra_layers_n,
                                               add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
                                    for i in range(n_scales)])

        self.out_cumulative = out_cumulative
        self.return_only_hr = return_only_hr

    @property
    def num_scales(self):
        return len(self.heads)

    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
            -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        :param ms_inputs: list of inputs of different resolutions, from HR to LR
        :param smallest_scales_num: int or None, number of the smallest scales to process
            (None means all scales)
        :return: depending on return_only_hr:
            True: only the highest-resolution output
            False: list of outputs of different resolutions, from HR to LR
        """
        if smallest_scales_num is None:
            assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
            smallest_scales_num = len(self.heads)
        else:
            assert smallest_scales_num == len(ms_inputs) <= len(self.heads), \
                (len(self.heads), len(ms_inputs), smallest_scales_num)

        cur_heads = self.heads[-smallest_scales_num:]
        ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]

        # process scales from LR to HR, feeding each tail's last feature map
        # (upsampled if needed) into the next, higher-resolution tail
        all_outputs = []
        prev_tail_features = None
        for i in range(len(ms_features)):
            scale_i = -i - 1

            cur_tail_input = ms_features[scale_i]
            if prev_tail_features is not None:
                if prev_tail_features.shape != cur_tail_input.shape:
                    prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
                                                       mode='bilinear', align_corners=False)
                cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)

            cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)

            prev_tail_features = cur_tail_feats
            all_outputs.append(cur_out)

        if self.out_cumulative:
            # each finer output is refined by adding the upsampled coarser result
            all_outputs_cum = [all_outputs[0]]
            for i in range(1, len(ms_features)):
                cur_out = all_outputs[i]
                cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
                                                      mode='bilinear', align_corners=False)
                all_outputs_cum.append(cur_out_cum)
            all_outputs = all_outputs_cum

        if self.return_only_hr:
            return all_outputs[-1]
        else:
            return all_outputs[::-1]  # reorder from LR-first back to HR-first


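# Usage sketch (illustrative, not part of the original file): inputs go HR -> LR
# in factor-2 steps, and the outputs come back in the same HR -> LR order:
#   net = MultiscaleResNet(input_nc=3, output_nc=3, n_scales=3)
#   ms_inputs = [torch.randn(1, 3, s, s) for s in (256, 128, 64)]
#   outs = net(ms_inputs)
#   # shapes: (1, 3, 256, 256), (1, 3, 128, 128), (1, 3, 64, 64)

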
class MultiscaleDiscriminatorSimple(nn.Module):
    def __init__(self, ms_impl):
        super().__init__()
        self.ms_impl = nn.ModuleList(ms_impl)

    @property
    def num_scales(self):
        return len(self.ms_impl)

    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
            -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        :param ms_inputs: list of inputs of different resolutions, from HR to LR
        :param smallest_scales_num: int or None, number of the smallest scales to process
            (None means all scales)
        :return: list of (prediction, features) pairs for different resolutions, from HR to LR
        """
        if smallest_scales_num is None:
            assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
            smallest_scales_num = len(self.ms_impl)  # this class has no heads; count its own discriminators
        else:
            assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
                (len(self.ms_impl), len(ms_inputs), smallest_scales_num)

        return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]


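# Usage sketch (illustrative; make_single_discr is a hypothetical factory for
# per-scale discriminators that return (prediction, [features, ...]) tuples,
# which is what the stacked mixin below expects):
#   discr = MultiscaleDiscriminatorSimple([make_single_discr() for _ in range(3)])
#   preds_and_feats = discr(ms_inputs)  # HR -> LR list of (prediction, features)

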
class SingleToMultiScaleInputMixin:
    def forward(self, x: torch.Tensor) -> List:
        # build an HR -> LR pyramid by downscaling x by factors 1, 2, ..., 2**(num_scales - 1);
        # H and W should therefore be divisible by 2**(num_scales - 1)
        orig_height, orig_width = x.shape[2:]
        factors = [2 ** i for i in range(self.num_scales)]
        ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
                     for f in factors]
        return super().forward(ms_inputs)


class GeneratorMultiToSingleOutputMixin:
    def forward(self, x):
        # keep only the highest-resolution output
        return super().forward(x)[0]


class DiscriminatorMultiToSingleOutputMixin:
    def forward(self, x):
        out_feat_tuples = super().forward(x)
        # highest-resolution prediction, plus all per-scale features flattened into one list
        return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]


class DiscriminatorMultiToSingleOutputStackedMixin:
    def __init__(self, *args, return_feats_only_levels=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_feats_only_levels = return_feats_only_levels

    def forward(self, x):
        out_feat_tuples = super().forward(x)
        outs = [out for out, _ in out_feat_tuples]
        # upsample every prediction to the highest resolution and stack along channels
        scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
                                                 mode='bilinear', align_corners=False)
                                   for cur_out in outs[1:]]
        out = torch.cat(scaled_outs, dim=1)
        if self.return_feats_only_levels is not None:
            feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
        else:
            feat_lists = [flist for _, flist in out_feat_tuples]
        feats = [f for flist in feat_lists for f in flist]
        return out, feats


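# Sketch (illustrative): with n scales, `out` stacks the n per-scale predictions
# along the channel dimension at the highest resolution, and `feats` flattens
# the selected per-scale feature lists, e.g. for feature-matching losses.

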
class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin,
                                 MultiscaleDiscriminatorSimple):
    pass


class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
    pass

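
# Minimal smoke test (illustrative, not part of the original module); it assumes
# the saicinpainting package is importable and exercises the single-input
# generator wrapper end to end.
if __name__ == '__main__':
    net = MultiscaleResNetSingle(input_nc=3, output_nc=3, ngf=32, n_downsampling=2,
                                 n_blocks_head=1, n_blocks_tail=2, n_scales=3)
    x = torch.randn(1, 3, 256, 256)  # H and W divisible by 2**(n_scales - 1)
    out = net(x)
    print(tuple(out.shape))  # expected: (1, 3, 256, 256)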