pytorch-image-models

Fork
0
87 lines · 3.3 KB
import torch
from torch import nn as nn
try:
5
    from inplace_abn.functions import inplace_abn, inplace_abn_sync
6
    has_iabn = True
7
except ImportError:
8
    has_iabn = False
9

10
    def inplace_abn(x, weight, bias, running_mean, running_var,
11
                    training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12
        raise ImportError(
13
            "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14

15
    def inplace_abn_sync(**kwargs):
16
        inplace_abn(**kwargs)
17

18

19
class InplaceAbn(nn.Module):
20
    """Activated Batch Normalization
21

22
    This gathers a BatchNorm and an activation function in a single module
23

24
    Parameters
25
    ----------
26
    num_features : int
27
        Number of feature channels in the input and output.
28
    eps : float
29
        Small constant to prevent numerical issues.
30
    momentum : float
31
        Momentum factor applied to compute running statistics.
32
    affine : bool
33
        If `True` apply learned scale and shift transformation after normalization.
34
    act_layer : str or nn.Module type
35
        Name or type of the activation functions, one of: `leaky_relu`, `elu`
36
    act_param : float
37
        Negative slope for the `leaky_relu` activation.
38
    """
39

40
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
41
                 act_layer="leaky_relu", act_param=0.01, drop_layer=None):
42
        super(InplaceAbn, self).__init__()
43
        self.num_features = num_features
44
        self.affine = affine
45
        self.eps = eps
46
        self.momentum = momentum
47
        if apply_act:
48
            if isinstance(act_layer, str):
49
                assert act_layer in ('leaky_relu', 'elu', 'identity', '')
50
                self.act_name = act_layer if act_layer else 'identity'
51
            else:
52
                # convert act layer passed as type to string
53
                if act_layer == nn.ELU:
54
                    self.act_name = 'elu'
55
                elif act_layer == nn.LeakyReLU:
56
                    self.act_name = 'leaky_relu'
57
                elif act_layer is None or act_layer == nn.Identity:
58
                    self.act_name = 'identity'
59
                else:
60
                    assert False, f'Invalid act layer {act_layer.__name__} for IABN'
61
        else:
62
            self.act_name = 'identity'
63
        self.act_param = act_param
64
        if self.affine:
65
            self.weight = nn.Parameter(torch.ones(num_features))
66
            self.bias = nn.Parameter(torch.zeros(num_features))
67
        else:
68
            self.register_parameter('weight', None)
69
            self.register_parameter('bias', None)
70
        self.register_buffer('running_mean', torch.zeros(num_features))
71
        self.register_buffer('running_var', torch.ones(num_features))
72
        self.reset_parameters()
73

74
    def reset_parameters(self):
75
        nn.init.constant_(self.running_mean, 0)
76
        nn.init.constant_(self.running_var, 1)
77
        if self.affine:
78
            nn.init.constant_(self.weight, 1)
79
            nn.init.constant_(self.bias, 0)
80

81
    def forward(self, x):
82
        output = inplace_abn(
83
            x, self.weight, self.bias, self.running_mean, self.running_var,
84
            self.training, self.momentum, self.eps, self.act_name, self.act_param)
85
        if isinstance(output, tuple):
86
            output = output[0]
87
        return output
Cookie usage

We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.

By clicking "Accept", you give JSC "SberTech" consent to process your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.

You can disable the use of cookies yourself in your browser settings.