google-research

Форк
0
195 строк · 6.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""ResNet handler.
17

18
  Adapted from
19
  https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
20

21
  Two primary changes from original ResNet code:
22
  1) Tapped delay line op is added to the output of every residual computation
23
    - See cascaded_networks.models.layers & cascaded_networks.models.tdl
24
  2) The timestep is set on the TDL in the forward pass
25
"""
26
import functools
27
import numpy as np
28
import torch
29
import torch.nn as nn
30
from cascaded_networks.models import custom_ops
31
from cascaded_networks.models import layers as res_layers
32
from cascaded_networks.models import model_utils
33

34

35
class ResNet(nn.Module):
  """ResNet base class with optional cascaded (time-stepped) execution.

  When ``cascaded`` is enabled, the head layer and every residual block have
  their timestep set at the start of each forward pass.
  """

  def __init__(self, name, block, layers, num_classes, **kwargs):
    """Initialize resnet.

    Args:
      name: Architecture name stored on the instance (e.g. 'resnet18').
      block: Residual block class; must expose an ``expansion`` attribute and
        accept the constructor arguments used in ``_make_layer``.
      layers: Four ints, the number of blocks in each residual stage.
      num_classes: Output size of the final linear classifier.
      **kwargs: Options read here: ``cascaded`` (default False), ``time_bn``
        (defaults to ``cascaded``), ``bn_opts`` (required when both
        ``cascaded`` and ``time_bn`` are set), ``final_bias_off`` (default
        False). Also forwarded to the head layer and to ``_make_layer``
        (``tdl_mode``, ``tdl_alpha``, ``noise_var``).
    """
    super(ResNet, self).__init__()
    self.name = name
    self._layers_arch = layers
    self._cascaded = kwargs.get('cascaded', False)
    self._time_bn = kwargs.get('time_bn', self._cascaded)

    # Set up batch norm operation (also sets self._norm_layer).
    self._norm_layer_op = self._setup_bn_op(**kwargs)
    if self._cascaded and self._time_bn:
      # Forward the completed copy of bn_opts (with n_timesteps) downstream.
      # `kwargs` is this call's own dict, so rebinding the key here does not
      # touch the caller's options.
      kwargs['bn_opts'] = self.bn_opts

    # Head layer
    self.inplanes = 64
    self.layer0 = res_layers.HeadLayer(self.inplanes,
                                       self._norm_layer_op,
                                       **kwargs)

    # Residual Layers. _make_layer advances self.inplanes as a side effect,
    # so these calls must stay in this order.
    self.layer1 = self._make_layer(block, 64, layers[0], **kwargs)
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **kwargs)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, **kwargs)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, **kwargs)
    # Plain list for iteration; the stages are already registered as
    # submodules through the layer1..layer4 attributes above.
    self.layers = [self.layer1, self.layer2, self.layer3, self.layer4]

    # Final layer
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    final_bias = not kwargs.get('final_bias_off', False)
    self.fc = nn.Linear(512 * block.expansion, num_classes, bias=final_bias)

    # Weight initialization
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, (self._norm_layer, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def _setup_bn_op(self, **kwargs):
    """Select the normalization layer class and build its factory.

    Sets ``self._norm_layer`` (used by the weight-init loop) and returns the
    callable used to construct normalization layers.

    With ``cascaded`` and ``time_bn`` both enabled:
      - ``custom_ops.BatchNorm2d`` is used.
      - A copy of ``kwargs['bn_opts']`` gets ``n_timesteps`` added and is
        stored on ``self.bn_opts``.
      - That copy is bound as the first argument of the factory.

    Otherwise ``nn.BatchNorm2d`` itself is returned.

    Raises:
      KeyError: if time-dependent BN is requested without ``bn_opts``.
    """
    if self._cascaded and self._time_bn:
      self._norm_layer = custom_ops.BatchNorm2d

      # Setup batchnorm opts. Copy first so that adding n_timesteps does not
      # write into the caller's dict, which may be reused for other models.
      self.bn_opts = dict(kwargs['bn_opts'])
      self.bn_opts['n_timesteps'] = self.timesteps
      return functools.partial(self._norm_layer, self.bn_opts)

    self._norm_layer = nn.BatchNorm2d
    return self._norm_layer

  def _make_layer(self, block, planes, blocks, stride=1, **kwargs):
    """Build one residual stage of ``blocks`` blocks.

    The first block may change the stride or channel count, so it gets a
    1x1-conv ``downsample`` shortcut when needed. The remaining blocks keep
    the shape. Updates ``self.inplanes`` to ``planes * block.expansion``.

    Args:
      block: Residual block class.
      planes: Base channel count for this stage.
      blocks: Number of blocks in the stage.
      stride: Stride of the first block.
      **kwargs: Reads ``tdl_mode`` (default 'OSD'), ``tdl_alpha`` (default
        0.0) and ``noise_var`` (default 0.0).

    Returns:
      An ``nn.Sequential`` holding the stage's blocks.
    """
    tdl_mode = kwargs.get('tdl_mode', 'OSD')
    tdl_alpha = kwargs.get('tdl_alpha', 0.0)
    noise_var = kwargs.get('noise_var', 0.0)

    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
      downsample = nn.Sequential(
          custom_ops.conv1x1(self.inplanes, planes * block.expansion, stride),
      )
    layers = []
    layers.append(
        block(self.inplanes,
              planes,
              stride,
              downsample,
              self._norm_layer_op,
              tdl_alpha=tdl_alpha,
              tdl_mode=tdl_mode,
              noise_var=noise_var,
              cascaded=self._cascaded,
              time_bn=self._time_bn))

    self.inplanes = planes * block.expansion
    for _ in range(1, blocks):
      layers.append(
          block(self.inplanes,
                planes,
                norm_layer=self._norm_layer_op,
                tdl_alpha=tdl_alpha,
                tdl_mode=tdl_mode,
                noise_var=noise_var,
                cascaded=self._cascaded,
                time_bn=self._time_bn))

    return nn.Sequential(*layers)

  @property
  def timesteps(self):
    """Number of timesteps: total residual blocks + 1 if cascaded, else 1."""
    if not self._cascaded:
      return 1
    # Cast to a builtin int so the value stored in bn_opts and returned to
    # callers is a plain Python int rather than a NumPy scalar.
    return int(np.sum(self._layers_arch)) + 1

  def _set_time(self, t):
    """Propagate timestep ``t`` to the head layer and every residual block."""
    self.layer0.set_time(t)
    for layer in self.layers:
      for block in layer:
        block.set_time(t)

  def forward(self, x, t):
    """Run the network.

    Args:
      x: Input batch passed to the head layer.
      t: Timestep; only used (via ``_set_time``) when the model is cascaded.

    Returns:
      Classifier outputs from ``self.fc``.
    """
    # Set time on all blocks
    if self._cascaded:
      self._set_time(t)

    # Head layer
    out = self.layer0(x)

    # Res Layers
    for layer in self.layers:
      out = layer(out)

    # Final layer
    out = self.avgpool(out)
    out = torch.flatten(out, 1)

    # Classification
    out = self.fc(out)

    return out
164

165

166
def make_resnet(arch, block, layers, pretrained, **kwargs):
  """Construct a ResNet and optionally load pretrained weights into it.

  Args:
    arch: Architecture name, passed to ``ResNet`` as its ``name``.
    block: Residual block class.
    layers: Number of blocks in each stage.
    pretrained: If truthy, the model is passed through
      ``model_utils.load_model`` together with ``kwargs``.
    **kwargs: Forwarded to ``ResNet`` (must supply ``num_classes``) and, as
      a dict, to ``model_utils.load_model`` when ``pretrained`` is set.

  Returns:
    The constructed model, or whatever ``model_utils.load_model`` returns
    for it when ``pretrained`` is set.
  """
  net = ResNet(arch, block, layers, **kwargs)
  if not pretrained:
    return net
  return model_utils.load_model(net, kwargs)
171

172

173
def resnet18(pretrained=False, **kwargs):
  """Build ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage.

  Args:
    pretrained: Whether to load pretrained weights (see ``make_resnet``).
    **kwargs: Forwarded unchanged to ``make_resnet``.

  Returns:
    The ResNet-18 model.
  """
  block_type = res_layers.BasicBlock
  blocks_per_stage = [2, 2, 2, 2]
  return make_resnet('resnet18', block_type, blocks_per_stage, pretrained,
                     **kwargs)
176

177

178
def resnet34(pretrained=False, **kwargs):
  """Build ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage.

  Args:
    pretrained: Whether to load pretrained weights (see ``make_resnet``).
    **kwargs: Forwarded unchanged to ``make_resnet``.

  Returns:
    The ResNet-34 model.
  """
  block_type = res_layers.BasicBlock
  blocks_per_stage = [3, 4, 6, 3]
  return make_resnet('resnet34', block_type, blocks_per_stage, pretrained,
                     **kwargs)
181

182

183
def resnet50(pretrained=False, **kwargs):
  """Build ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage.

  Args:
    pretrained: Whether to load pretrained weights (see ``make_resnet``).
    **kwargs: Forwarded unchanged to ``make_resnet``.

  Returns:
    The ResNet-50 model.
  """
  block_type = res_layers.Bottleneck
  blocks_per_stage = [3, 4, 6, 3]
  return make_resnet('resnet50', block_type, blocks_per_stage, pretrained,
                     **kwargs)
186

187

188
def resnet101(pretrained=False, **kwargs):
  """Build ResNet-101: Bottleneck with [3, 4, 23, 3] blocks per stage.

  Args:
    pretrained: Whether to load pretrained weights (see ``make_resnet``).
    **kwargs: Forwarded unchanged to ``make_resnet``.

  Returns:
    The ResNet-101 model.
  """
  block_type = res_layers.Bottleneck
  blocks_per_stage = [3, 4, 23, 3]
  return make_resnet('resnet101', block_type, blocks_per_stage, pretrained,
                     **kwargs)
191

192

193
def resnet152(pretrained=False, **kwargs):
  """Build ResNet-152: Bottleneck with [3, 8, 36, 3] blocks per stage.

  Args:
    pretrained: Whether to load pretrained weights (see ``make_resnet``).
    **kwargs: Forwarded unchanged to ``make_resnet``.

  Returns:
    The ResNet-152 model.
  """
  block_type = res_layers.Bottleneck
  blocks_per_stage = [3, 8, 36, 3]
  return make_resnet('resnet152', block_type, blocks_per_stage, pretrained,
                     **kwargs)
196

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.