""" MLP module w/ dropout and configurable activation layer
2

3
Hacked together by / Copyright 2020 Ross Wightman
4
"""
5
from functools import partial
6

7
from torch import nn as nn
8

9
from .grn import GlobalResponseNorm
10
from .helpers import to_2tuple
11

12

13
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


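# Usage sketch (editor's addition, not part of the upstream source; shapes are
# assumed). With use_conv=False the MLP acts on the last dimension, e.g. a
# batch of ViT tokens; with use_conv=True fc1/fc2 are 1x1 convs over NCHW maps.
#
#   import torch
#   mlp = Mlp(in_features=768, hidden_features=3072, drop=0.1)
#   tokens = torch.randn(2, 197, 768)    # (batch, num_tokens, channels)
#   out = mlp(tokens)                    # -> torch.Size([2, 197, 768])

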
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # use second half of width for gate

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


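# Usage sketch (editor's addition, assumed sizes). fc1 projects to
# hidden_features, which is split in half; one half gates the other through
# act_layer, so norm and fc2 operate on hidden_features // 2 channels.
#
#   import torch
#   glu = GluMlp(in_features=256, hidden_features=512)   # default Sigmoid gate
#   out = glu(torch.randn(2, 196, 256))                  # -> torch.Size([2, 196, 256])

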
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)


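# SwiGLUPacked is GluMlp with a SiLU gate taken from the first half of fc1
# (gate_last=False). Sketch (editor's addition, assumed sizes): a packed
# hidden_features of 512 holds both branches, so it matches the split SwiGLU
# below built with hidden_features=256.
#
#   swiglu_packed = SwiGLUPacked(in_features=256, hidden_features=512)

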
class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


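# Usage sketch (editor's addition, assumed sizes). The separate fc1_g / fc1_x
# projections keep the gate and value branches as distinct weight tensors,
# which is what makes checkpoint remapping simpler than the packed variant.
#
#   import torch
#   swiglu = SwiGLU(in_features=256, hidden_features=512)
#   out = swiglu(torch.randn(2, 196, 256))   # -> torch.Size([2, 196, 256])

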
class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


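# Usage sketch (editor's addition). gate_layer is constructed with
# hidden_features and is expected to halve the channel width, which is why
# fc2 is built on hidden_features // 2. HalvingGate below is a hypothetical
# stand-in that only illustrates that contract; the gMLP gate in timm
# (SpatialGatingUnit) additionally mixes information across tokens.
#
#   class HalvingGate(nn.Module):
#       def __init__(self, dim):
#           super().__init__()          # dim unused in this toy gate
#       def forward(self, x):
#           u, v = x.chunk(2, dim=-1)   # split channels in half
#           return u * v
#
#   gated = GatedMlp(in_features=256, hidden_features=512, gate_layer=HalvingGate)

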
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x


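# Usage sketch (editor's addition, assumed shapes). ConvMlp consumes NCHW
# feature maps and preserves the spatial dimensions.
#
#   import torch
#   conv_mlp = ConvMlp(in_features=64, hidden_features=128, norm_layer=nn.BatchNorm2d)
#   out = conv_mlp(torch.randn(2, 64, 56, 56))   # -> torch.Size([2, 64, 56, 56])

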
class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.grn(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
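

# Usage sketch (editor's addition, assumed shapes). With use_conv=False the
# GRN runs in channels_last mode, so inputs are NHWC; with use_conv=True the
# layers become 1x1 convs and inputs are NCHW.
#
#   import torch
#   grn_mlp = GlobalResponseNormMlp(in_features=96, hidden_features=384)
#   out = grn_mlp(torch.randn(2, 56, 56, 96))   # NHWC -> torch.Size([2, 56, 56, 96])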
