pytorch-image-models
1""" Normalization + Activation Layers
2
3Provides Norm+Act fns for standard PyTorch norm layers such as
4* BatchNorm
5* GroupNorm
6* LayerNorm
7
8This allows swapping with alternative layers that are natively both norm + act such as
9* EvoNorm (evo_norm.py)
10* FilterResponseNorm (filter_response_norm.py)
11* InplaceABN (inplace_abn.py)
12
13Hacked together by / Copyright 2022 Ross Wightman
14"""
from typing import Union, List, Optional, Any

import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d

from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert


def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
    act_layer = get_act_layer(act_layer)  # string -> nn.Module
    act_kwargs = act_kwargs or {}
    if act_layer is not None and apply_act:
        if inplace:
            act_kwargs['inplace'] = inplace
        act = act_layer(**act_kwargs)
    else:
        act = nn.Identity()
    return act


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(
            self,
            num_features,
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
            device=None,
            dtype=None,
    ):
        try:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
                **factory_kwargs,
            )
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
            )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        x = F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
        x = self.drop(x)
        x = self.act(x)
        return x


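# Illustrative note (assumption, not part of the module): because BatchNormAct2d inherits from
# nn.BatchNorm2d rather than wrapping it as a .bn submodule, its state_dict keys
# ('weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked') match those of a plain
# BatchNorm2d, so checkpoints trained with separate bn + act layers load without remapping, e.g.
#
#   bn_act = BatchNormAct2d(64)
#   bn_act.load_state_dict(nn.BatchNorm2d(64).state_dict())  # loads cleanly; act/drop hold no params

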
class SyncBatchNormAct(nn.SyncBatchNorm):
    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
    # but ONLY when used in conjunction with the timm conversion function below.
    # Do not create this module directly or use the PyTorch conversion function.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
        if hasattr(self, "drop"):
            x = self.drop(x)
        if hasattr(self, "act"):
            x = self.act(x)
        return x


def convert_sync_batchnorm(module, process_group=None):
    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if isinstance(module, BatchNormAct2d):
            # convert timm norm + act layer
            module_output = SyncBatchNormAct(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group=process_group,
            )
            # set act and drop attr from the original module
            module_output.act = module.act
            module_output.drop = module.drop
        else:
            # convert standard BatchNorm layers
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
    del module
    return module_output


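# Illustrative usage (assumes a distributed process group has already been initialized):
# this timm conversion function should be used instead of torch.nn.SyncBatchNorm.convert_sync_batchnorm
# when the model contains BatchNormAct2d layers, otherwise the act/drop sub-layers would be lost
# during conversion:
#
#   model = convert_sync_batchnorm(model)   # BatchNormAct2d -> SyncBatchNormAct, BatchNorm2d -> SyncBatchNorm
#   model = model.cuda()                    # then move to the local rank's device as usual

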
class FrozenBatchNormAct2d(torch.nn.Module):
    """
    BatchNormAct2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def _load_from_state_dict(
            self,
            state_dict: dict,
            prefix: str,
            local_metadata: dict,
            strict: bool,
            missing_keys: List[str],
            unexpected_keys: List[str],
            error_msgs: List[str],
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        x = x * scale + bias
        x = self.act(self.drop(x))
        return x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"


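# Illustrative note (restating the forward math above): with frozen statistics, batch norm reduces
# to a per-channel affine transform, which is exactly what FrozenBatchNormAct2d computes:
#
#   scale = weight / sqrt(running_var + eps)
#   bias  = bias - running_mean * scale
#   y     = act(drop(x * scale + bias))
#
# Since no buffers are ever updated, the layer behaves identically in train and eval modes.

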
def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct` layers
    of the provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
        res = FrozenBatchNormAct2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = freeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res


def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` and `FrozenBatchNormAct2d` layers of the provided module into
    `BatchNorm2d` or `BatchNormAct2d` respectively. If `module` is itself an instance of one of these
    frozen types, it is converted and returned. Otherwise, the module is walked recursively and
    submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, FrozenBatchNormAct2d):
        res = BatchNormAct2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, FrozenBatchNorm2d):
        res = torch.nn.BatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res


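# Illustrative usage (sketch): freezing is typically applied to a pretrained backbone whose BN
# statistics should not drift during fine-tuning, and can be reversed later:
#
#   model = freeze_batch_norm_2d(model)     # BatchNormAct2d -> FrozenBatchNormAct2d, BatchNorm2d -> FrozenBatchNorm2d
#   # ... fine-tune with frozen norm stats ...
#   model = unfreeze_batch_norm_2d(model)   # convert back to trainable norm layers
#
# Both helpers return the (possibly new) root module, so the return value must be kept.

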
def _num_groups(num_channels, num_groups, group_size):
    if group_size:
        assert num_channels % group_size == 0
        return num_channels // group_size
    return num_groups


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
            self,
            num_channels,
            num_groups=32,
            eps=1e-5,
            affine=True,
            group_size=None,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(GroupNormAct, self).__init__(
            _num_groups(num_channels, num_groups, group_size),
            num_channels,
            eps=eps,
            affine=affine,
        )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


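# Illustrative usage: unlike nn.GroupNorm(num_groups, num_channels), GroupNormAct takes num_channels
# first so it can be swapped for other norm layers that are constructed from channels alone.
# group_size, when given, overrides num_groups (num_channels // group_size groups):
#
#   gn_act = GroupNormAct(64)                 # default num_groups=32 -> 32 groups of 2 channels
#   gn_act = GroupNormAct(64, group_size=16)  # 64 // 16 = 4 groups of 16 channels

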
class GroupNorm1Act(nn.GroupNorm):
    """ Group Normalization with a single group (normalizes over C and all spatial dims) + Activation
    """
    def __init__(
            self,
            num_channels,
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


class LayerNormAct(nn.LayerNorm):
    def __init__(
            self,
            normalization_shape: Union[int, List[int], torch.Size],
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


class LayerNormAct2d(nn.LayerNorm):
    """ LayerNorm + Activation for NCHW tensors.
    Permutes to NHWC to apply the norm over the channel dim, then permutes back.
    """
    def __init__(
            self,
            num_channels,
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        x = self.drop(x)
        x = self.act(x)
        return x
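
# Illustrative note: LayerNormAct2d normalizes over the channel dim of an NCHW tensor by permuting
# to NHWC for F.layer_norm and permuting back, e.g.
#
#   ln_act = LayerNormAct2d(64, act_layer='gelu')
#   y = ln_act(torch.randn(2, 64, 7, 7))   # output keeps the same NCHW shape
#
# whereas LayerNormAct applies the norm directly over the trailing dimension(s) of its input.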