""" Normalization layers and wrappers

Norm layer definitions that support fast norm and consistent channel arg order (always first arg).

Hacked together by / Copyright 2022 Ross Wightman
"""
import numbers
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm


class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
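
# Usage sketch (illustrative comment, not part of the original module). Because
# num_channels is the first positional arg, this class can be swapped into call
# sites written for BatchNorm-style constructors:
#
#   norm_layer = GroupNorm                 # or nn.BatchNorm2d, LayerNorm2d, ...
#   norm = norm_layer(64)                  # 64 channels, default 32 groups
#   y = norm(torch.randn(2, 64, 8, 8))     # output shape matches input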


class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
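
# Illustrative note (not in the original file): with a single group the statistics
# are computed per sample over C and all spatial dims together, e.g.
#
#   norm = GroupNorm1(64)
#   y = norm(torch.randn(2, 64, 8, 8))   # mean/var taken over 64*8*8 values per sample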


class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
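
# Usage sketch (illustrative comment, not part of the original module). The permutes
# move channels to the last dim so the norm is taken over C at each spatial position,
# then NCHW layout is restored:
#
#   norm = LayerNorm2d(64)
#   x = torch.randn(2, 64, 8, 8)           # NCHW
#   y = norm(x)                            # still (2, 64, 8, 8)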


def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)


@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x


def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    u = x.mean(dim=1, keepdim=True)
    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
    return x
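
# Illustrative note (not in the original file): both helpers compute LayerNorm over
# the channel dim of an NCHW tensor without permuting, so they should match the
# permute-based path up to numerical precision, e.g.
#
#   x = torch.randn(2, 64, 8, 8)
#   w, b = torch.ones(64), torch.zeros(64)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), w, b, 1e-6).permute(0, 3, 1, 2)
#   out = _layer_norm_cf(x, w, b, 1e-6)    # torch.allclose(out, ref, atol=1e-5)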


class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            x = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
        return x


class RmsNorm(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        normalized_shape = channels
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
        return x
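
# Illustrative note (not in the original file): RMSNorm skips the mean-centering of
# LayerNorm and rescales by the root-mean-square over the normalized dims, roughly
#
#   v = x.pow(2).mean(dim=-1, keepdim=True)
#   y = x * torch.rsqrt(v + eps) * weight
#
# fast_rms_norm is expected to dispatch to the fused APEX kernel when available and
# fall back to an equivalent PyTorch implementation otherwise.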