pytorch-image-models
87 строк · 3.3 Кб
1import torch2from torch import nn as nn3
try:
    from inplace_abn.functions import inplace_abn, inplace_abn_sync
    has_iabn = True
except ImportError:
    # The optional `inplace_abn` package is not installed. Define stand-ins so
    # this module still imports; they only fail (with install instructions)
    # when one of the kernels is actually called.
    has_iabn = False

    def inplace_abn(x, weight, bias, running_mean, running_var,
                    training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
        """Placeholder for the missing kernel: always raises ImportError with install instructions."""
        raise ImportError(
            "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")

    def inplace_abn_sync(*args, **kwargs):
        """Placeholder for the synchronized kernel.

        Forwards positional and keyword arguments unchanged to the `inplace_abn`
        placeholder, so callers using either calling style get the same
        ImportError rather than a TypeError from a signature mismatch.
        """
        return inplace_abn(*args, **kwargs)
18
class InplaceAbn(nn.Module):
    """Activated Batch Normalization

    This gathers a BatchNorm and an activation function in a single module;
    the computation is delegated to the ``inplace_abn`` kernel.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input and output.
    eps : float
        Small constant to prevent numerical issues.
    momentum : float
        Momentum factor applied to compute running statistics.
    affine : bool
        If `True` apply learned scale and shift transformation after normalization.
    apply_act : bool
        If `False`, no activation is applied and `act_layer` is ignored.
    act_layer : str or nn.Module type or None
        Activation to fuse with the normalization. Accepted names are
        `leaky_relu`, `elu`, `identity` (or `''`). Accepted types are
        `nn.LeakyReLU`, `nn.ELU` and `nn.Identity`. `None` means identity.
    act_param : float
        Negative slope for the `leaky_relu` activation (passed to the kernel
        as its activation parameter).
    drop_layer : optional
        Accepted for signature compatibility; not used by this module.

    Raises
    ------
    ValueError
        If `apply_act` is true and `act_layer` is not a supported activation.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
                 act_layer="leaky_relu", act_param=0.01, drop_layer=None):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        # When apply_act is False the act_layer argument is ignored (not validated).
        self.act_name = self._resolve_act_name(act_layer) if apply_act else 'identity'
        self.act_param = act_param
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Register as None so `self.weight` / `self.bias` exist and are passed to the kernel.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @staticmethod
    def _resolve_act_name(act_layer):
        """Map an activation spec (name, layer type, or None) to the kernel's activation name.

        Raises ValueError for unsupported values. An explicit raise is used rather
        than `assert`, so the check survives `python -O`.
        """
        if isinstance(act_layer, str):
            if act_layer not in ('leaky_relu', 'elu', 'identity', ''):
                raise ValueError(f"Invalid act layer '{act_layer}' for IABN")
            return act_layer or 'identity'
        # Convert an act layer passed as a type to its string name.
        if act_layer is None or act_layer is nn.Identity:
            return 'identity'
        if act_layer is nn.ELU:
            return 'elu'
        if act_layer is nn.LeakyReLU:
            return 'leaky_relu'
        # Instances (e.g. nn.ReLU()) have no __name__, so fall back to repr for the message.
        name = getattr(act_layer, '__name__', repr(act_layer))
        raise ValueError(f'Invalid act layer {name} for IABN')

    def reset_parameters(self):
        """Reset running statistics to (mean=0, var=1) and, if affine, weight=1 / bias=0."""
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        """Apply the fused normalization + activation to `x` via the `inplace_abn` kernel."""
        output = inplace_abn(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.act_name, self.act_param)
        # Some kernel versions return a tuple; only the first element (the output tensor) is used.
        if isinstance(output, tuple):
            output = output[0]
        return output