# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ResNet handler.

Adapted from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

Two primary changes from the original ResNet code:
1) A tapped delay line (TDL) op is added to the output of every residual
   computation (see project.models.layers & project.models.tdl).
2) The timestep is set on the TDL in the forward pass.
"""
import functools
import numpy as np
import torch
import torch.nn as nn
from cascaded_networks.models import custom_ops
from cascaded_networks.models import layers as res_layers
from cascaded_networks.models import model_utils


class ResNet(nn.Module):
  """ResNet base class."""

  def __init__(self, name, block, layers, num_classes, **kwargs):
    """Initialize ResNet."""
    super(ResNet, self).__init__()
    self.name = name
    self._layers_arch = layers
    self._cascaded = kwargs.get('cascaded', False)
    self._time_bn = kwargs.get('time_bn', self._cascaded)

    # Set up batch norm operation
    self._norm_layer_op = self._setup_bn_op(**kwargs)

    # Head layer
    self.inplanes = 64
    self.layer0 = res_layers.HeadLayer(self.inplanes,
                                       self._norm_layer_op,
                                       **kwargs)

    # Residual layers
    self.layer1 = self._make_layer(block, 64, layers[0], **kwargs)
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **kwargs)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, **kwargs)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, **kwargs)
    self.layers = [self.layer1, self.layer2, self.layer3, self.layer4]

    # Final layer
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    final_bias = not kwargs.get('final_bias_off', False)
    self.fc = nn.Linear(512 * block.expansion, num_classes, bias=final_bias)

    # Weight initialization
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, (self._norm_layer, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def _setup_bn_op(self, **kwargs):
    if self._cascaded and self._time_bn:
      # Custom batch norm indexed by timestep (bn_opts supplies n_timesteps).
      self._norm_layer = custom_ops.BatchNorm2d

      # Setup batchnorm opts
      self.bn_opts = kwargs['bn_opts']
      self.bn_opts['n_timesteps'] = self.timesteps
      norm_layer_op = functools.partial(self._norm_layer, self.bn_opts)
    else:
      self._norm_layer = nn.BatchNorm2d
      norm_layer_op = self._norm_layer

    return norm_layer_op

  def _make_layer(self, block, planes, blocks, stride=1, **kwargs):
    tdl_mode = kwargs.get('tdl_mode', 'OSD')
    tdl_alpha = kwargs.get('tdl_alpha', 0.0)
    noise_var = kwargs.get('noise_var', 0.0)

    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
      downsample = nn.Sequential(
          custom_ops.conv1x1(self.inplanes, planes * block.expansion, stride),
      )

    layers = []
    layers.append(
        block(self.inplanes,
              planes,
              stride,
              downsample,
              self._norm_layer_op,
              tdl_alpha=tdl_alpha,
              tdl_mode=tdl_mode,
              noise_var=noise_var,
              cascaded=self._cascaded,
              time_bn=self._time_bn))

    self.inplanes = planes * block.expansion
    for _ in range(1, blocks):
      layers.append(
          block(self.inplanes,
                planes,
                norm_layer=self._norm_layer_op,
                tdl_alpha=tdl_alpha,
                tdl_mode=tdl_mode,
                noise_var=noise_var,
                cascaded=self._cascaded,
                time_bn=self._time_bn))

    return nn.Sequential(*layers)

  @property
  def timesteps(self):
    if self._cascaded:
      n_timesteps = np.sum(self._layers_arch) + 1
    else:
      n_timesteps = 1
    return n_timesteps

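  # For reference: a cascaded resnet18 has layers == [2, 2, 2, 2], giving
  # np.sum([2, 2, 2, 2]) + 1 == 9 timesteps; a non-cascaded model always
  # reports a single timestep.
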
  def _set_time(self, t):
    self.layer0.set_time(t)
    for layer in self.layers:
      for block in layer:
        block.set_time(t)

  def forward(self, x, t):
    # Set time on all blocks
    if self._cascaded:
      self._set_time(t)

    # Head layer
    out = self.layer0(x)

    # Res layers
    for layer in self.layers:
      out = layer(out)

    # Final layer
    out = self.avgpool(out)
    out = torch.flatten(out, 1)

    # Classification
    out = self.fc(out)

    return out


def make_resnet(arch, block, layers, pretrained, **kwargs):
  # num_classes (and all other model options) reach ResNet via kwargs.
  model = ResNet(arch, block, layers, **kwargs)
  if pretrained:
    model = model_utils.load_model(model, kwargs)
  return model


def resnet18(pretrained=False, **kwargs):
  return make_resnet('resnet18', res_layers.BasicBlock, [2, 2, 2, 2],
                     pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
  return make_resnet('resnet34', res_layers.BasicBlock, [3, 4, 6, 3],
                     pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
  return make_resnet('resnet50', res_layers.Bottleneck, [3, 4, 6, 3],
                     pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
  return make_resnet('resnet101', res_layers.Bottleneck, [3, 4, 23, 3],
                     pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
  return make_resnet('resnet152', res_layers.Bottleneck, [3, 8, 36, 3],
                     pretrained, **kwargs)
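

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions: the
# cascaded_networks package is importable, and time_bn=False is passed
# explicitly so the project-specific bn_opts dict required by the
# time-indexed batch norm is not needed; num_classes is forwarded via kwargs.
if __name__ == '__main__':
  model = resnet18(pretrained=False,
                   num_classes=10,
                   cascaded=True,
                   time_bn=False)
  model.eval()

  x = torch.randn(2, 3, 224, 224)

  # In cascaded mode the same input is presented at every timestep; each call
  # sets t on the tapped delay lines before running the usual ResNet stages.
  with torch.no_grad():
    for t in range(model.timesteps):  # resnet18: 8 blocks + 1 = 9 timesteps
      logits = model(x, t)

  print(logits.shape)  # expected: torch.Size([2, 10])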