# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Resnet block components."""
import torch.nn as nn
from cascaded_networks.models import custom_ops
from cascaded_networks.models import tdl


class HeadLayer(nn.Module):
  """Head layer of ResNet."""

  def __init__(self, planes, norm_layer, **kwargs):
    """Initialize head layer."""
    super(HeadLayer, self).__init__()
    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', kwargs['cascaded'])

    inplanes = 3

    if kwargs.get('imagenet', False):
      self.conv1 = nn.Conv2d(inplanes,
                             planes,
                             kernel_size=7,
                             stride=2,
                             padding=3,
                             bias=False)
    else:
      self.conv1 = nn.Conv2d(inplanes,
                             planes,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             bias=False)

    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    self.t = t
    if t == 0:
      self.tdline.reset()

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out, self.t) if self.time_bn else self.bn1(out)
    out = self.relu(out)
    out = self.maxpool(out)

    if self.cascaded:
      # Add delay line
      out = self.tdline(out)

    return out


class BasicBlock(nn.Module):
  """Basic resnet block."""
  expansion = 1

  def __init__(self,
               inplanes,
               planes,
               stride=1,
               downsample=None,
               norm_layer=None,
               **kwargs):
    """Initialize basic block."""
    super(BasicBlock, self).__init__()

    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', kwargs['cascaded'])
    self.downsample = downsample
    self.stride = stride

    # Setup ops
    self.conv1 = custom_ops.conv3x3(inplanes, planes, stride)
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = custom_ops.conv3x3(planes, planes)
    self.bn2 = norm_layer(planes)

    # TDL
    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    self.t = t
    if t == 0:
      self.tdline.reset()

  def _residual_block(self, x):
    # Conv1
    out = self.conv1(x)
    out = self.bn1(out, self.t) if self.time_bn else self.bn1(out)
    out = self.relu(out)

    # Conv2
    out = self.conv2(out)
    out = self.bn2(out, self.t) if self.time_bn else self.bn2(out)

    return out

  def forward(self, x):
    # Identity
    identity = x
    if self.downsample is not None:
      identity = self.downsample(x)

    # Residual
    residual = self._residual_block(x)

    # TDL if cascaded
    if self.cascaded:
      residual = self.tdline(residual)

    # Identity + Residual
    out = residual + identity

    # Nonlinear activation
    out = self.relu(out)

    return out


class Bottleneck(nn.Module):
  """Bottleneck Block."""
  expansion = 4

  def __init__(self,
               inplanes,
               planes,
               stride=1,
               downsample=None,
               norm_layer=None,
               **kwargs):
    """Initialize bottleneck block."""
    super(Bottleneck, self).__init__()
    base_width = 64
    width = int(planes * (base_width / 64.))

    self.downsample = downsample
    self.stride = stride
    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', kwargs['cascaded'])

    self.conv1 = custom_ops.conv1x1(inplanes, width)
    self.bn1 = norm_layer(width)
    self.conv2 = custom_ops.conv3x3(width, width, stride)
    self.bn2 = norm_layer(width)
    self.conv3 = custom_ops.conv1x1(width, planes * self.expansion)
    self.bn3 = norm_layer(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)

    if self.cascaded:
      tdl_mode = kwargs.get('tdl_mode', 'OSD')
      self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)

  def set_time(self, t):
    self.t = t
    if t == 0:
      self.tdline.reset()

  def _residual_block(self, x):
    # Conv 1
    out = self.conv1(x)
    out = self.bn1(out, self.t) if self.time_bn else self.bn1(out)
    out = self.relu(out)

    # Conv 2
    out = self.conv2(out)
    out = self.bn2(out, self.t) if self.time_bn else self.bn2(out)
    out = self.relu(out)

    # Conv 3
    out = self.conv3(out)
    out = self.bn3(out, self.t) if self.time_bn else self.bn3(out)

    return out

  def forward(self, x):
    # Identity
    identity = x
    if self.downsample is not None:
      identity = self.downsample(x)

    # Residual
    residual = self._residual_block(x)

    # TDL if cascaded
    if self.cascaded:
      residual = self.tdline(residual)

    # Identity + Residual
    out = residual + identity

    # Nonlinear activation
    out = self.relu(out)

    return out

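# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It exercises
# only the non-cascaded path, where `time_bn` defaults to False, so a plain
# nn.BatchNorm2d works as `norm_layer` and no time-delay line is created.
# The BasicBlock example assumes custom_ops.conv3x3 behaves like the standard
# torchvision-style 3x3 convolution helper (padding=1, bias=False); the
# HeadLayer example depends only on this file.
if __name__ == '__main__':
  import torch

  # Head layer on a CIFAR-sized input: 3 -> 64 channels, spatial size halved
  # by the max pool (32x32 -> 16x16).
  head = HeadLayer(planes=64, norm_layer=nn.BatchNorm2d, cascaded=False)
  x = torch.randn(1, 3, 32, 32)
  h = head(x)
  print(h.shape)  # Expected: torch.Size([1, 64, 16, 16])

  # Basic residual block with matching input/output channels, so no
  # downsample branch is needed and identity + residual shapes agree.
  block = BasicBlock(inplanes=64, planes=64, norm_layer=nn.BatchNorm2d,
                     cascaded=False)
  y = block(h)
  print(y.shape)  # Expected: torch.Size([1, 64, 16, 16])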