pytorch-image-models

Форк
0
129 строк · 3.7 Кб
1
""" Conv2d + BN + Act

Hacked together by / Copyright 2020 Ross Wightman
"""
5
import functools
6
from torch import nn as nn
7

8
from .create_conv2d import create_conv2d
9
from .create_norm_act import get_norm_act_layer
10

11

12
class ConvNormAct(nn.Module):
    """Conv2d followed by a combined normalization + activation module.

    The norm and act layers are fused into a single `norm_act` module via
    `get_norm_act_layer`, which accepts separate `norm_layer` / `act_layer`
    definitions for backwards compatibility with older model configs.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding spec forwarded to `create_conv2d` ('' selects its default).
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: If True, the conv layer has a bias term.
        apply_act: If False, the activation inside the norm-act layer is disabled.
        norm_layer: Normalization layer class (default `nn.BatchNorm2d`).
        norm_kwargs: Extra kwargs for the norm-act layer (not mutated).
        act_layer: Activation layer class (default `nn.ReLU`).
        act_kwargs: Extra kwargs for the activation (not mutated).
        drop_layer: Optional dropout/drop-path layer, injected into the norm-act layer.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            drop_layer=None,
    ):
        super(ConvNormAct, self).__init__()
        # Copy kwargs dicts so a caller-supplied dict is never mutated
        # (`drop_layer` is injected into norm_kwargs below).
        norm_kwargs = dict(norm_kwargs) if norm_kwargs else {}
        act_kwargs = dict(act_kwargs) if act_kwargs else {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(
            out_channels,
            apply_act=apply_act,
            act_kwargs=act_kwargs,
            **norm_kwargs,
        )

    @property
    def in_channels(self):
        # Mirror the wrapped conv so this module can be introspected like a Conv2d.
        return self.conv.in_channels

    @property
    def out_channels(self):
        # Mirror the wrapped conv so this module can be introspected like a Conv2d.
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
62

63

64
ConvBnAct = ConvNormAct
65

66

67
def create_aa(aa_layer, channels, stride=2, enable=True):
    """Instantiate an anti-aliasing (blur/pool) layer, or a no-op when disabled.

    Args:
        aa_layer: Layer class or `functools.partial` of one; None for no-op.
        channels: Channel count passed to layers that take a `channels` argument.
        stride: Downsampling stride; plain `nn.AvgPool2d` receives it as its
            first positional argument (i.e. as the pooling kernel size).
        enable: If False, return `nn.Identity()` regardless of `aa_layer`.

    Returns:
        An instantiated layer, or `nn.Identity()` when disabled / unset.
    """
    if not aa_layer or not enable:
        return nn.Identity()

    if isinstance(aa_layer, functools.partial):
        # A partial already carries its own arguments; dispatch on the wrapped class.
        avg_pool_partial = issubclass(aa_layer.func, nn.AvgPool2d)
        return aa_layer() if avg_pool_partial else aa_layer(channels)

    if issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)

    return aa_layer(channels=channels, stride=stride)
79

80

81
class ConvNormActAa(nn.Module):
    """Conv2d + norm + act with an optional anti-aliasing (blur-pool) layer.

    When `aa_layer` is set and `stride == 2`, the conv runs at stride 1 and the
    downsampling is delegated to the anti-aliasing layer appended after norm/act.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Overall stride of the module (conv or aa layer, see above).
        padding: Padding spec forwarded to `create_conv2d` ('' selects its default).
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: If True, the conv layer has a bias term.
        apply_act: If False, the activation inside the norm-act layer is disabled.
        norm_layer: Normalization layer class (default `nn.BatchNorm2d`).
        norm_kwargs: Extra kwargs for the norm-act layer (not mutated).
        act_layer: Activation layer class (default `nn.ReLU`).
        act_kwargs: Extra kwargs for the activation (not mutated).
        aa_layer: Anti-aliasing layer class/partial; only used when `stride == 2`.
        drop_layer: Optional dropout/drop-path layer, injected into the norm-act layer.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            aa_layer=None,
            drop_layer=None,
    ):
        super(ConvNormActAa, self).__init__()
        use_aa = aa_layer is not None and stride == 2
        # Copy kwargs dicts so a caller-supplied dict is never mutated
        # (`drop_layer` is injected into norm_kwargs below).
        norm_kwargs = dict(norm_kwargs) if norm_kwargs else {}
        act_kwargs = dict(act_kwargs) if act_kwargs else {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
    def in_channels(self):
        # Mirror the wrapped conv so this module can be introspected like a Conv2d.
        return self.conv.in_channels

    @property
    def out_channels(self):
        # Mirror the wrapped conv so this module can be introspected like a Conv2d.
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.aa(x)
        return x
130

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.