GFPGAN

Форк
0
/
arcface_arch.py 
245 строк · 7.9 Кб
1
import torch.nn as nn
2
from basicsr.utils.registry import ARCH_REGISTRY
3

4

5
def conv3x3(inplanes, outplanes, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of zero padding.

    With ``stride=1`` the spatial size is preserved (kernel 3, padding 1);
    larger strides downsample.

    Args:
        inplanes (int): Number of input channels.
        outplanes (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    return nn.Conv2d(
        in_channels=inplanes,
        out_channels=outplanes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
14

15

16
class BasicBlock(nn.Module):
    """Two-convolution residual block ("basic" block) for ResNetArcFace.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where the
    shortcut is ``downsample(x)`` when a downsample module is given and ``x``
    itself otherwise.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the first 3x3 convolution. Default: 1.
        downsample (nn.Module): Optional module that maps the input to the
            shortcut branch (e.g. to match channels/resolution). Default: None.
    """
    expansion = 1  # ratio of output channels to ``planes``

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Submodules are registered in this order; it fixes the parameter and
        # state_dict ordering, so keep it stable.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Residual branch: conv -> BN -> ReLU -> conv -> BN.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # Shortcut branch, projected only when a downsample module exists.
        identity = x if self.downsample is None else self.downsample(x)

        y += identity
        return self.relu(y)
54

55

56
class IRBlock(nn.Module):
    """Improved residual (IR) block used by ResNetArcFace.

    Residual branch: ``bn0 -> conv1 (3x3, inplanes->inplanes) -> bn1 -> PReLU
    -> conv2 (3x3, inplanes->planes, strided) -> bn2``, optionally followed by
    an SEBlock. The branch is added to the shortcut (``downsample(x)`` when
    given, else ``x``) and the sum goes through the same PReLU module.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the second 3x3 convolution. Default: 1.
        downsample (nn.Module): Optional module that maps the input to the
            shortcut branch. Default: None.
        use_se (bool): Whether to apply the SEBlock (squeeze and excitation
            block) to the residual branch. Default: True.
    """
    expansion = 1  # ratio of output channels to ``planes``

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        # Registration order fixes the parameter/state_dict ordering; keep it.
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        # A single PReLU module is reused for both activations, so they share
        # one learnable slope parameter.
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        # Residual branch.
        y = self.prelu(self.bn1(self.conv1(self.bn0(x))))
        y = self.bn2(self.conv2(y))
        if self.use_se:
            y = self.se(y)

        # Shortcut branch, projected only when a downsample module exists.
        identity = x if self.downsample is None else self.downsample(x)

        y += identity
        return self.prelu(y)
101

102

103
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) for ResNetArcFace.

    ``conv1`` (1x1) maps the input to ``planes`` channels, ``conv2`` (3x3)
    applies ``stride``, and ``conv3`` (1x1) expands to ``planes * expansion``
    channels. The result is added to the shortcut (``downsample(x)`` when
    given, else ``x``) and passed through ReLU.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Width of the inner convolutions; the block outputs
            ``planes * Bottleneck.expansion`` channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module): Optional module that maps the input to the
            shortcut branch; needed whenever the input shape differs from the
            output shape, otherwise the residual addition fails. Default: None.
    """
    expansion = 4  # ratio of output channels to ``planes``

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        expanded = planes * self.expansion
        # Registration order fixes the parameter/state_dict ordering; keep it.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Residual branch: reduce -> spatial conv -> expand.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Shortcut branch, projected only when a downsample module exists.
        identity = x if self.downsample is None else self.downsample(x)

        y += identity
        return self.relu(y)
147

148

149
class SEBlock(nn.Module):
    """Squeeze-and-excitation (SE) channel gate used in the IRBlock.

    Each channel of the (4-D) input is average-pooled to a scalar. The
    resulting per-sample vector goes through ``Linear -> PReLU -> Linear ->
    Sigmoid``, and the input is multiplied channel-wise by these weights,
    which lie in (0, 1).

    Args:
        channel (int): Number of input (and output) channels.
        reduction (int): Channel reduction ratio of the hidden layer. The hidden
            width is ``channel // reduction``, clamped to at least 1.
            Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Without the clamp, channel < reduction gives a zero-width hidden layer.
        # The gate would then ignore its input and always output sigmoid(bias).
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden), nn.PReLU(), nn.Linear(hidden, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # squeeze: (b, c, 1, 1) -> (b, c)
        y = self.fc(y).view(b, c, 1, 1)  # excitation weights, broadcast over H, W
        return x * y
169

170

171
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    The network maps a single-channel image batch to 512-d embeddings. The stem
    (3x3 conv, BN, PReLU, 2x2 max-pool) is followed by four stages of residual
    blocks of width 64/128/256/512 (times ``block.expansion``). Stages 2-4
    downsample by 2. ``fc5`` expects the last stage to yield an 8x8 feature
    map, which with these strides corresponds to a 128x128 input.

    Args:
        block (str | type): Residual block class, or its name: ``'IRBlock'``,
            ``'BasicBlock'`` or ``'Bottleneck'``.
        layers (tuple(int)): Block numbers in each of the four stages.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block).
            Only forwarded to blocks whose constructor accepts ``use_se``
            (e.g. ``IRBlock``). Default: True.

    Raises:
        ValueError: If ``block`` is a string that names no known block type.
    """

    def __init__(self, block, layers, use_se=True):
        super(ResNetArcFace, self).__init__()
        block = self._resolve_block(block)
        self.inplanes = 64
        self.use_se = use_se
        # Channels produced by layer4 (512 for IRBlock/BasicBlock, 2048 for Bottleneck).
        out_channels = 512 * block.expansion

        # Registration order below fixes the parameter/state_dict ordering.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(out_channels * 8 * 8, 512)  # assumes an 8x8 map after layer4
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _resolve_block(block):
        """Map a block name to its class; non-string values are returned as-is."""
        if not isinstance(block, str):
            return block
        block_types = {'BasicBlock': BasicBlock, 'IRBlock': IRBlock, 'Bottleneck': Bottleneck}
        if block not in block_types:
            raise ValueError('Unsupported block type {!r}; expected one of {}.'.format(block, sorted(block_types)))
        return block_types[block]

    @staticmethod
    def _accepts_use_se(block):
        """Return True if ``block``'s constructor can take a ``use_se`` keyword."""
        import inspect
        try:
            params = inspect.signature(block).parameters.values()
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to always forwarding use_se.
            return True
        return any(p.name == 'use_se' or p.kind is inspect.Parameter.VAR_KEYWORD for p in params)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one stage of residual blocks.

        The first block applies ``stride``. When the stride or the channel count
        changes, the first block also gets a 1x1 conv + BN projection on its
        shortcut. Afterwards ``self.inplanes`` holds this stage's output width,
        ``planes * block.expansion``, for use by the next stage.

        Args:
            block (type): Residual block class (must define ``expansion``).
            planes (int): Base width of the stage.
            num_blocks (int): Number of blocks; the first block is always built.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        # BasicBlock/Bottleneck have no ``use_se`` parameter; passing it would raise TypeError.
        block_kwargs = {'use_se': self.use_se} if self._accepts_use_se(block) else {}

        layers = [block(self.inplanes, planes, stride, downsample, **block_kwargs)]
        # Later blocks consume the first block's output: planes * expansion channels.
        self.inplanes = out_planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, **block_kwargs))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute face embeddings.

        Args:
            x (Tensor): Batch of shape (N, 1, H, W). H = W = 128 gives the
                8x8 map after layer4 that ``fc5`` expects.

        Returns:
            Tensor: Embeddings of shape (N, 512), normalized by ``bn5``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.flatten(1)  # (N, C, 8, 8) -> (N, C * 64)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
246

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.