pytorch-image-models

Форк
0
84 строки · 3.0 Кб
1
""" Split Attention Conv2d (for ResNeSt Models)
2

3
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
4

5
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6

7
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8
"""
9
import torch
10
import torch.nn.functional as F
11
from torch import nn
12

13
from .helpers import make_divisible
14

15

16
class RadixSoftmax(nn.Module):
    """Normalize split-attention logits across the radix axis.

    For ``radix > 1`` the input is viewed as ``(batch, cardinality, radix, -1)``,
    the cardinality and radix axes are swapped, a softmax is taken over the radix
    axis, and the result is flattened to ``(batch, -1)``. The flattened output is
    therefore laid out radix-major: ``[radix][cardinality][channel]``.

    For ``radix == 1`` (or less) an element-wise sigmoid is applied instead and
    the input shape is returned unchanged.

    Args:
        radix: number of splits the softmax competes over.
        cardinality: number of groups the channel axis is divided into.
    """

    def __init__(self, radix: int, cardinality: int):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)
        if self.radix > 1:
            # (B, cardinality, radix, C') -> (B, radix, cardinality, C') so that
            # dim=1 is the radix axis the softmax normalizes over.
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            # reshape (not view): the transposed tensor is non-contiguous.
            x = x.reshape(batch, -1)
        else:
            # A single split has nothing to compete with; gate with a sigmoid.
            x = torch.sigmoid(x)
        return x
31

32

33
class SplitAttn(nn.Module):
    """Split-Attention (aka Splat)

    Applies a grouped convolution that produces ``radix`` feature splits, derives
    per-split channel attention from the globally pooled sum of the splits, and
    returns the attention-weighted sum of the splits.

    Args:
        in_channels: number of input channels of the feature convolution.
        out_channels: number of output channels (channels per split); falls back
            to ``in_channels`` when ``None`` or 0.
        kernel_size: kernel size of the feature convolution.
        stride: stride of the feature convolution.
        padding: padding of the feature convolution; ``kernel_size // 2`` when ``None``.
        dilation: dilation of the feature convolution.
        groups: cardinality. The feature convolution uses ``groups * radix``
            groups; both 1x1 attention convolutions use ``groups``.
        bias: whether the feature convolution has a bias term.
        radix: number of feature splits.
        rd_ratio: reduction ratio used to size the attention bottleneck when
            ``rd_channels`` is ``None``.
        rd_channels: explicit bottleneck width per split (multiplied by ``radix``);
            takes precedence over ``rd_ratio``.
        rd_divisor: divisor handed to ``make_divisible`` when ``rd_channels`` is ``None``.
        act_layer: activation class, instantiated as ``act_layer(inplace=True)``.
        norm_layer: optional normalization class called with a channel count;
            ``nn.Identity`` is used when falsy.
        drop_layer: optional zero-argument factory for a module applied after the
            first norm; ``nn.Identity`` is used when ``None``.
        **kwargs: extra keyword arguments forwarded to the feature ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
        super().__init__()
        out_channels = out_channels or in_channels
        self.radix = radix
        # The feature conv emits all radix splits stacked along the channel axis.
        mid_chs = out_channels * radix
        if rd_channels is None:
            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        else:
            attn_chs = rd_channels * radix

        # Default padding keeps spatial size for odd kernels at stride 1 / dilation 1.
        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act0 = act_layer(inplace=True)
        # Attention bottleneck: out_channels -> attn_chs -> mid_chs, all 1x1 convs.
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention-weighted sum of the splits.

        The result has shape ``(B, RC // radix, H, W)`` where ``(B, RC, H, W)`` is
        the shape of the feature convolution output.
        """
        x = self.conv(x)
        x = self.bn0(x)
        x = self.drop(x)
        x = self.act0(x)

        B, RC, H, W = x.shape
        if self.radix > 1:
            # Expose the splits as their own axis and fuse them by summation
            # before pooling.
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        # Global average pool over the spatial dims, kept as (B, C, 1, 1) for the 1x1 convs.
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            # Weight each split by its attention and sum the splits back together.
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        return out.contiguous()
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.