lama
20 строк · 609.0 Байт
1import torch.nn as nn2
3
class SELayer(nn.Module):
    """Squeeze-and-Excitation block that reweights feature-map channels.

    The input is globally average-pooled to one value per channel
    ("squeeze"). That vector goes through a two-layer bottleneck MLP
    (Linear -> ReLU -> Linear -> Sigmoid, no biases) ("excitation"). The
    resulting per-channel sigmoid gates are multiplied back onto the input.

    Args:
        channel: Number of input channels ``C``. Must match ``x.size(1)``
            in :meth:`forward`.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1.

    Raises:
        ValueError: If ``channel`` or ``reduction`` is less than 1.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        if channel < 1:
            raise ValueError(f"channel must be >= 1, got {channel}")
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Without the clamp, channel < reduction yields a zero-width hidden
        # layer. The MLP output would then always be zero, so every gate is
        # sigmoid(0) = 0.5 and the block has no trainable effect.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned gate.

        Args:
            x: 4-D tensor of shape ``(b, c, h, w)`` with ``c == channel``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Squeeze: (b, c, h, w) -> (b, c, 1, 1) -> (b, c).
        y = self.avg_pool(x).view(b, c)
        # Excitation: per-channel sigmoid gates, reshaped so they broadcast
        # over the spatial dimensions.
        y = self.fc(y).view(b, c, 1, 1)
        # Broadcasting over (h, w) makes an explicit expand_as unnecessary.
        return x * y