1""" MLP module w/ dropout and configurable activation layer
2
3Hacked together by / Copyright 2020 Ross Wightman
4"""
5from functools import partial6
7from torch import nn as nn8
9from .grn import GlobalResponseNorm10from .helpers import to_2tuple11
12


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
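
# Illustrative usage sketch (not part of the original module): the sizes below
# (768-dim tokens, 4x hidden expansion) are assumptions for the example, not
# defaults defined here. With use_conv=False the module expects channels-last
# input, e.g. (batch, tokens, channels), and preserves that shape.
#
#   import torch
#   mlp = Mlp(in_features=768, hidden_features=3072, act_layer=nn.GELU, drop=0.1)
#   y = mlp(torch.randn(2, 197, 768))  # -> (2, 197, 768)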


class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # use second half of width for gate

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
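
# Illustrative sketch (assumed sizes, not defaults from this file): fc1 projects to
# hidden_features, which is then chunked in half, so hidden_features must be even and
# fc2 maps hidden_features // 2 back out. With gate_last=True the second half gates
# the first through the activation (nn.Sigmoid by default, i.e. classic GLU).
#
#   glu = GluMlp(in_features=256, hidden_features=1024)
#   y = glu(torch.randn(8, 196, 256))  # -> (8, 196, 256)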


SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
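
# Illustrative note: SwiGLUPacked keeps both halves of the SwiGLU projection in a
# single fc1 (a "packed" layout), with nn.SiLU applied to the first chunk because
# gate_last=False. Example with assumed sizes:
#
#   swiglu = SwiGLUPacked(in_features=512, hidden_features=2048)  # two chunks of 1024
#   y = swiglu(torch.randn(4, 64, 512))  # -> (4, 64, 512)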


class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMlp above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
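
# Illustrative sketch (assumed sizes): the split fc1_g / fc1_x layout keeps the gate
# and value projections as separate Linear layers, which makes remapping checkpoints
# that store them separately straightforward. Functionally: act(fc1_g(x)) * fc1_x(x).
#
#   swiglu = SwiGLU(in_features=512, hidden_features=1024)
#   y = swiglu(torch.randn(4, 64, 512))  # -> (4, 64, 512)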


class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
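
# Illustrative sketch: gate_layer is expected to be a module constructor taking the
# hidden width, e.g. the spatial gating unit used by gMLP, which halves the channel
# dim (hence hidden_features // 2 into fc2). With gate_layer=None it reduces to a
# plain MLP. Sizes below are assumptions for the example.
#
#   mlp = GatedMlp(in_features=256, hidden_features=1536, gate_layer=None)
#   y = mlp(torch.randn(2, 196, 256))  # -> (2, 196, 256)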


class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
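
# Illustrative sketch (assumed sizes): ConvMlp operates on NCHW feature maps and keeps
# the spatial dims, using 1x1 convs in place of Linear layers.
#
#   mlp = ConvMlp(in_features=64, hidden_features=256, norm_layer=nn.BatchNorm2d)
#   y = mlp(torch.randn(2, 64, 56, 56))  # -> (2, 64, 56, 56)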


class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.grn(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
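
# Illustrative sketch (assumed sizes): with use_conv=False the GRN layer is built with
# channels_last=True, so the input should be channels-last, e.g. NHWC as in
# ConvNeXt-V2 style blocks; with use_conv=True it expects NCHW and uses 1x1 convs.
#
#   mlp = GlobalResponseNormMlp(in_features=96, hidden_features=384)
#   y = mlp(torch.randn(2, 56, 56, 96))  # NHWC -> (2, 56, 56, 96)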