1""" Normalization layers and wrappers
2
3Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
4
5Hacked together by / Copyright 2022 Ross Wightman
6"""
import numbers
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm


class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

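# Usage sketch (illustrative, not part of the original module). Because num_channels is the
# first positional arg, this class can stand in wherever a BatchNorm-style
# `norm_layer(num_features)` constructor is expected:
#   >>> norm = GroupNorm(64)                  # 64 channels, num_groups defaults to 32
#   >>> y = norm(torch.randn(2, 64, 8, 8))    # shape preserved: [2, 64, 8, 8]
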
class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

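# Note (illustrative, not part of the original module): with a single group the statistics
# are computed over all of C and the spatial dims together, i.e. one mean/var per sample:
#   >>> gn1 = GroupNorm1(64)
#   >>> y = gn1(torch.randn(2, 64, 8, 8))     # each sample normalized over its full [64, 8, 8] slice
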
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x

class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x

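# Usage sketch (illustrative, not part of the original module): LayerNorm2d normalizes over
# the channel dim of NCHW tensors by permuting to NHWC, applying layer norm over the last
# dim, and permuting back, at the cost of two permutes per call:
#   >>> norm = LayerNorm2d(64)
#   >>> y = norm(torch.randn(2, 64, 56, 56))  # output stays NCHW: [2, 64, 56, 56]
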
def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)

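# Note (illustrative, not part of the original module): a tensor converted to channels_last
# is not contiguous in the default contiguous_format sense, so it will take the manual norm
# path in LayerNormExp2d.forward below:
#   >>> t = torch.randn(2, 64, 8, 8).to(memory_format=torch.channels_last)
#   >>> t.is_contiguous(memory_format=torch.contiguous_format)
#   False
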
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x

def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    u = x.mean(dim=1, keepdim=True)
    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
    return x

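# Equivalence sketch (illustrative, not part of the original module): both helpers compute a
# biased-variance layer norm over dim=1 (channels), matching F.layer_norm applied in NHWC:
#   >>> x = torch.randn(2, 8, 4, 4)
#   >>> w, b = torch.ones(8), torch.zeros(8)
#   >>> ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), w, b, 1e-6).permute(0, 3, 1, 2)
#   >>> torch.allclose(_layer_norm_cf(x, w, b, 1e-6), ref, atol=1e-5)
#   True
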
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            x = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
        return x

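# Usage sketch (illustrative, not part of the original module): contiguous NCHW input takes
# the permute + F.layer_norm path; channels_last (non-contiguous) input takes the manual
# _layer_norm_cf path, which avoids the two permutes:
#   >>> norm = LayerNormExp2d(64)
#   >>> y1 = norm(torch.randn(2, 64, 8, 8))                                        # F.layer_norm path
#   >>> y2 = norm(torch.randn(2, 64, 8, 8).to(memory_format=torch.channels_last))  # manual path
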
class RmsNorm(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        normalized_shape = channels
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
        return x
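
# Reference semantics (an illustrative sketch, not part of the original module): the common
# RMSNorm formulation scales by the root mean square over the normalized dims, with no mean
# subtraction or bias; fast_rms_norm dispatches to a fused (e.g. APEX) kernel when one is
# available, and exact eps placement may differ between implementations:
#   >>> def rms_norm_ref(x, weight, eps=1e-6):
#   ...     return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight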