pytorch-image-models
129 lines · 3.7 KB
""" Conv2d + BN + Act

Hacked together by / Copyright 2020 Ross Wightman
"""
import functools

from torch import nn as nn

from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer
10
11
class ConvNormAct(nn.Module):
    """Conv2d followed by a fused normalization + activation layer.

    The norm-act layer attribute is named ``bn`` for backwards (weight)
    compatibility with checkpoints trained when the norm layer was a bare
    BatchNorm.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding value or symbolic string ('' lets create_conv2d pick).
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: If True, the convolution has a bias term.
        apply_act: If False, the norm-act layer skips its activation.
        norm_layer: Normalization layer class, combined with ``act_layer`` via
            ``get_norm_act_layer``.
        norm_kwargs: Extra kwargs forwarded to the norm-act layer. Not mutated.
        act_layer: Activation layer class.
        act_kwargs: Extra kwargs forwarded to the activation layer.
        drop_layer: Optional drop layer class, injected into the norm-act layer.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            drop_layer=None,
    ):
        super().__init__()
        # Copy caller-supplied kwarg dicts so the 'drop_layer' insertion below
        # cannot mutate a dict the caller owns (config dicts are often reused
        # across multiple layer instantiations).
        norm_kwargs = dict(norm_kwargs) if norm_kwargs else {}
        act_kwargs = dict(act_kwargs) if act_kwargs else {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(
            out_channels,
            apply_act=apply_act,
            act_kwargs=act_kwargs,
            **norm_kwargs,
        )

    @property
    def in_channels(self):
        # Delegate to the underlying conv so callers see the true channel count.
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
62
63
# Backwards-compatible alias; older model code imports ConvBnAct.
ConvBnAct = ConvNormAct
65
66
67def create_aa(aa_layer, channels, stride=2, enable=True):
68if not aa_layer or not enable:
69return nn.Identity()
70if isinstance(aa_layer, functools.partial):
71if issubclass(aa_layer.func, nn.AvgPool2d):
72return aa_layer()
73else:
74return aa_layer(channels)
75elif issubclass(aa_layer, nn.AvgPool2d):
76return aa_layer(stride)
77else:
78return aa_layer(channels=channels, stride=stride)
79
80
class ConvNormActAa(nn.Module):
    """Conv2d + norm + act with optional anti-aliased downsampling.

    When ``aa_layer`` is given and ``stride == 2``, the convolution runs at
    stride 1 and the spatial downsampling is delegated to the anti-aliasing
    layer (applied after norm + act). Otherwise the conv keeps the requested
    stride and ``self.aa`` is an ``nn.Identity``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Overall stride of the module (conv or AA layer).
        padding: Padding value or symbolic string ('' lets create_conv2d pick).
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: If True, the convolution has a bias term.
        apply_act: If False, the norm-act layer skips its activation.
        norm_layer: Normalization layer class, combined with ``act_layer`` via
            ``get_norm_act_layer``.
        norm_kwargs: Extra kwargs forwarded to the norm-act layer. Not mutated.
        act_layer: Activation layer class.
        act_kwargs: Extra kwargs forwarded to the activation layer.
        aa_layer: Optional anti-aliasing layer class (see ``create_aa``).
        drop_layer: Optional drop layer class, injected into the norm-act layer.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            aa_layer=None,
            drop_layer=None,
    ):
        super().__init__()
        use_aa = aa_layer is not None and stride == 2
        # Copy caller-supplied kwarg dicts so the 'drop_layer' insertion below
        # cannot mutate a dict the caller owns.
        norm_kwargs = dict(norm_kwargs) if norm_kwargs else {}
        act_kwargs = dict(act_kwargs) if act_kwargs else {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
    def in_channels(self):
        # Delegate to the underlying conv so callers see the true channel count.
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.aa(x)
        return x
130