# google-research
# 218 lines · 5.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Resnet block components."""
17import torch.nn as nn
18from cascaded_networks.models import custom_ops
19from cascaded_networks.models import tdl
20
21
class HeadLayer(nn.Module):
  """Head layer of ResNet: conv -> norm -> ReLU -> max-pool.

  When `cascaded` is enabled, the pooled output is additionally passed through
  a delay line built by `tdl.setup_tdl_kernel`.
  """

  def __init__(self, planes, norm_layer, **kwargs):
    """Initialize head layer.

    Args:
      planes: Number of output channels of the first convolution.
      norm_layer: Callable invoked as `norm_layer(planes)` to build the
        normalization module. When `time_bn` is enabled the resulting module
        is called as `bn(x, t)`, otherwise as `bn(x)`.
      **kwargs: Configuration values. Keys read here:
        cascaded: (required) whether to append a delay line to the output.
        time_bn: whether the norm layer receives the current time step;
          defaults to the value of `cascaded`.
        imagenet: if truthy, use a 7x7 stride-2 convolution instead of the
          3x3 stride-1 one.
        tdl_mode: mode passed to `tdl.setup_tdl_kernel` (default 'OSD'); the
          whole kwargs dict is forwarded as its second argument.

    Raises:
      KeyError: If 'cascaded' is not supplied.
    """
    super(HeadLayer, self).__init__()
    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', self.cascaded)

    # Current time step. Defined up front so `forward` works with `time_bn`
    # enabled even if `set_time` has not been called yet.
    self.t = 0

    # The head always consumes 3-channel input.
    inplanes = 3

    if kwargs.get('imagenet', False):
      conv_cfg = dict(kernel_size=7, stride=2, padding=3)
    else:
      conv_cfg = dict(kernel_size=3, stride=1, padding=1)
    self.conv1 = nn.Conv2d(inplanes, planes, bias=False, **conv_cfg)

    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    """Record the current time step and reset the delay line at t == 0.

    The delay line only exists for cascaded layers, so the reset is skipped
    when `cascaded` is falsy instead of failing on a missing attribute.

    Args:
      t: Current time step.
    """
    self.t = t
    if self.cascaded and t == 0:
      self.tdline.reset()

  def _apply_norm(self, norm, x):
    """Call `norm` on `x`, passing the time step when `time_bn` is set."""
    if self.time_bn:
      return norm(x, self.t)
    return norm(x)

  def forward(self, x):
    """Run conv, norm, ReLU and max-pool, then the delay line if cascaded.

    Args:
      x: Input passed directly to the first convolution.

    Returns:
      The pooled activations, routed through the delay line when cascaded.
    """
    out = self.conv1(x)
    out = self._apply_norm(self.bn1, out)
    out = self.relu(out)
    out = self.maxpool(out)

    if self.cascaded:
      # Add delay line
      out = self.tdline(out)

    return out
72
73
class BasicBlock(nn.Module):
  """Basic resnet block: two 3x3 conv/norm stages plus an identity shortcut.

  When `cascaded` is enabled, the residual branch is passed through a delay
  line built by `tdl.setup_tdl_kernel` before being added to the shortcut.
  """
  # Output channels are `planes * expansion`.
  expansion = 1

  def __init__(self,
               inplanes,
               planes,
               stride=1,
               downsample=None,
               norm_layer=None,
               **kwargs):
    """Initialize basic block.

    Args:
      inplanes: Number of input channels.
      planes: Number of output channels of both convolutions.
      stride: Stride of the first 3x3 convolution.
      downsample: Optional callable applied to the input to produce the
        shortcut branch; the input is used unchanged when None.
      norm_layer: Callable invoked as `norm_layer(planes)` to build each
        normalization module. It is called unconditionally, so a value must
        be supplied despite the None default.
      **kwargs: Configuration values. Keys read here:
        cascaded: (required) whether to apply a delay line to the residual.
        time_bn: whether norm layers receive the current time step; defaults
          to the value of `cascaded`.
        tdl_mode: mode passed to `tdl.setup_tdl_kernel` (default 'OSD'); the
          whole kwargs dict is forwarded as its second argument.

    Raises:
      KeyError: If 'cascaded' is not supplied.
    """
    super(BasicBlock, self).__init__()

    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', self.cascaded)
    self.downsample = downsample
    self.stride = stride

    # Current time step. Defined up front so `forward` works with `time_bn`
    # enabled even if `set_time` has not been called yet.
    self.t = 0

    # Setup ops
    self.conv1 = custom_ops.conv3x3(inplanes, planes, stride)
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = custom_ops.conv3x3(planes, planes)
    self.bn2 = norm_layer(planes)

    # TDL
    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    """Record the current time step and reset the delay line at t == 0.

    The delay line only exists for cascaded blocks, so the reset is skipped
    when `cascaded` is falsy instead of failing on a missing attribute.

    Args:
      t: Current time step.
    """
    self.t = t
    if self.cascaded and t == 0:
      self.tdline.reset()

  def _apply_norm(self, norm, x):
    """Call `norm` on `x`, passing the time step when `time_bn` is set."""
    if self.time_bn:
      return norm(x, self.t)
    return norm(x)

  def _residual_block(self, x):
    """Compute the residual branch: conv1 -> bn1 -> ReLU -> conv2 -> bn2."""
    # Conv1
    out = self.conv1(x)
    out = self._apply_norm(self.bn1, out)
    out = self.relu(out)

    # Conv2
    out = self.conv2(out)
    out = self._apply_norm(self.bn2, out)

    return out

  def forward(self, x):
    """Return ReLU(residual(x) + shortcut(x)).

    Args:
      x: Block input.

    Returns:
      The activated sum of the residual branch and the shortcut.
    """
    # Identity (optionally downsampled so it matches the residual branch)
    identity = x if self.downsample is None else self.downsample(x)

    # Residual
    residual = self._residual_block(x)

    # TDL if cascaded
    if self.cascaded:
      residual = self.tdline(residual)

    # Identity + Residual, followed by the nonlinear activation
    return self.relu(residual + identity)
142
143
class Bottleneck(nn.Module):
  """Bottleneck block: 1x1 -> 3x3 -> 1x1 conv/norm stages plus a shortcut.

  When `cascaded` is enabled, the residual branch is passed through a delay
  line built by `tdl.setup_tdl_kernel` before being added to the shortcut.
  """
  # Output channels are `planes * expansion`.
  expansion = 4

  def __init__(self,
               inplanes,
               planes,
               stride=1,
               downsample=None,
               norm_layer=None,
               **kwargs):
    """Initialize bottleneck block.

    Args:
      inplanes: Number of input channels.
      planes: Base channel count; the block outputs `planes * expansion`
        channels.
      stride: Stride of the middle 3x3 convolution.
      downsample: Optional callable applied to the input to produce the
        shortcut branch; the input is used unchanged when None.
      norm_layer: Callable invoked with a channel count to build each
        normalization module. It is called unconditionally, so a value must
        be supplied despite the None default.
      **kwargs: Configuration values. Keys read here:
        cascaded: (required) whether to apply a delay line to the residual.
        time_bn: whether norm layers receive the current time step; defaults
          to the value of `cascaded`.
        tdl_mode: mode passed to `tdl.setup_tdl_kernel` (default 'OSD'); the
          whole kwargs dict is forwarded as its second argument.

    Raises:
      KeyError: If 'cascaded' is not supplied.
    """
    super(Bottleneck, self).__init__()
    # With the fixed base width of 64 the inner width equals `planes`.
    base_width = 64
    width = int(planes * (base_width / 64.))
    out_planes = planes * self.expansion

    self.downsample = downsample
    self.stride = stride
    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', self.cascaded)

    # Current time step. Defined up front so `forward` works with `time_bn`
    # enabled even if `set_time` has not been called yet.
    self.t = 0

    self.conv1 = custom_ops.conv1x1(inplanes, width)
    self.bn1 = norm_layer(width)
    self.conv2 = custom_ops.conv3x3(width, width, stride)
    self.bn2 = norm_layer(width)
    self.conv3 = custom_ops.conv1x1(width, out_planes)
    self.bn3 = norm_layer(out_planes)
    self.relu = nn.ReLU(inplace=True)

    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    """Record the current time step and reset the delay line at t == 0.

    The delay line only exists for cascaded blocks, so the reset is skipped
    when `cascaded` is falsy instead of failing on a missing attribute.

    Args:
      t: Current time step.
    """
    self.t = t
    if self.cascaded and t == 0:
      self.tdline.reset()

  def _apply_norm(self, norm, x):
    """Call `norm` on `x`, passing the time step when `time_bn` is set."""
    if self.time_bn:
      return norm(x, self.t)
    return norm(x)

  def _residual_block(self, x):
    """Compute the residual branch through the three conv/norm stages."""
    # Conv 1
    out = self.conv1(x)
    out = self._apply_norm(self.bn1, out)
    out = self.relu(out)

    # Conv 2
    out = self.conv2(out)
    out = self._apply_norm(self.bn2, out)
    out = self.relu(out)

    # Conv 3 (no activation here; ReLU is applied after the shortcut sum)
    out = self.conv3(out)
    out = self._apply_norm(self.bn3, out)

    return out

  def forward(self, x):
    """Return ReLU(residual(x) + shortcut(x)).

    Args:
      x: Block input.

    Returns:
      The activated sum of the residual branch and the shortcut.
    """
    # Identity (optionally downsampled so it matches the residual branch)
    identity = x if self.downsample is None else self.downsample(x)

    # Residual
    residual = self._residual_block(x)

    # TDL if cascaded
    if self.cascaded:
      residual = self.tdline(residual)

    # Identity + Residual, followed by the nonlinear activation
    return self.relu(residual + identity)
219