pytorch-image-models
95 lines · 3.7 KB
1""" NormAct (Normalization + Activation Layer) Factory
2
3Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
4isntances in models. Where these are used it will be possible to swap separate BN + act layers with
5combined modules like IABN or EvoNorms.
6
7Hacked together by / Copyright 2020 Ross Wightman
8"""
9import types10import functools11
12from .evo_norm import *13from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d14from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d15from .inplace_abn import InplaceAbn16
# Map lower-cased, normalized layer-name strings to norm + act module types.
# Some entries are `functools.partial` bindings that pre-configure a variant
# (e.g. 'groupnorm1' is GroupNormAct fixed to a single group).
_NORM_ACT_MAP = {
    'batchnorm': BatchNormAct2d,
    'batchnorm2d': BatchNormAct2d,
    'groupnorm': GroupNormAct,
    'groupnorm1': functools.partial(GroupNormAct, num_groups=1),
    'layernorm': LayerNormAct,
    'layernorm2d': LayerNormAct2d,
    'evonormb0': EvoNorm2dB0,
    'evonormb1': EvoNorm2dB1,
    'evonormb2': EvoNorm2dB2,
    'evonorms0': EvoNorm2dS0,
    'evonorms0a': EvoNorm2dS0a,
    'evonorms1': EvoNorm2dS1,
    'evonorms1a': EvoNorm2dS1a,
    'evonorms2': EvoNorm2dS2,
    'evonorms2a': EvoNorm2dS2a,
    'frn': FilterResponseNormAct2d,
    'frntlu': FilterResponseNormTlu2d,
    'inplaceabn': InplaceAbn,
    'iabn': InplaceAbn,
}
# Set of all known norm+act types, used for fast membership tests in
# get_norm_act_layer(). FIX: use .values() directly instead of iterating
# .items() and discarding the keys.
_NORM_ACT_TYPES = set(_NORM_ACT_MAP.values())
# Types that accept an `act_layer` kwarg to define the activation type.
_NORM_ACT_REQUIRES_ARG = {
    BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
43
def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
    """Create an instance of a combined norm + activation layer.

    Args:
        layer_name: Name string, type, function, or partial identifying the norm+act layer.
        num_features: Number of input features/channels for the norm layer.
        act_layer: Optional activation layer, forwarded for types that accept one.
        apply_act: If False, the layer is created without an activation applied.
        jit: If True, wrap the created module with `torch.jit.script`.
        **kwargs: Extra keyword args forwarded to the layer constructor.

    Returns:
        The instantiated norm+act module (TorchScript-compiled when `jit=True`).
    """
    layer = get_norm_act_layer(layer_name, act_layer=act_layer)
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        # FIX: `torch` was never imported in this module; it was only in scope
        # via the `from .evo_norm import *` name leak. Import it explicitly at
        # the point of use so the dependency is robust to import changes.
        import torch
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance
51
def get_norm_act_layer(norm_layer, act_layer=None):
    """Resolve a norm+act layer type (or factory) from a name, type, or partial.

    Args:
        norm_layer: A layer-name string (e.g. 'batchnorm'), a norm+act type,
            a function that creates a norm+act layer, a plain norm type to be
            mapped to its norm+act equivalent, or a `functools.partial` of any
            of these. None or '' returns None.
        act_layer: Optional activation layer passed through (for backwards
            compat) to types that take an `act_layer` arg.

    Returns:
        A norm+act type or `functools.partial` with any partial kwargs and
        `act_layer` rebound, or None if `norm_layer` is None/empty.
    """
    if norm_layer is None:
        return None
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
    norm_act_kwargs = {}

    # unbind partial fn, so args can be rebound later
    if isinstance(norm_layer, functools.partial):
        norm_act_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        if not norm_layer:
            return None
        # normalize: drop underscores, lower-case, strip any '-suffix' qualifier
        layer_name = norm_layer.replace('_', '').lower().split('-')[0]
        norm_act_layer = _NORM_ACT_MAP[layer_name]
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, must be a lambda/fn that creates a norm_act layer
        norm_act_layer = norm_layer
    else:
        # map a plain norm type (e.g. nn.BatchNorm2d) to its norm+act equivalent
        type_name = norm_layer.__name__.lower()
        if type_name.startswith('batchnorm'):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith('groupnorm1'):
            # FIX: this prefix must be tested before the generic 'groupnorm'
            # check; previously the generic branch shadowed it, making this
            # branch unreachable and dropping the num_groups=1 binding.
            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
        elif type_name.startswith('groupnorm'):
            norm_act_layer = GroupNormAct
        elif type_name.startswith('layernorm2d'):
            norm_act_layer = LayerNormAct2d
        elif type_name.startswith('layernorm'):
            norm_act_layer = LayerNormAct
        else:
            assert False, f"No equivalent norm_act layer for {type_name}"

    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
        # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
        # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
        norm_act_kwargs.setdefault('act_layer', act_layer)
    if norm_act_kwargs:
        norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)  # bind/rebind args
    return norm_act_layer
89if norm_act_layer in _NORM_ACT_REQUIRES_ARG:90# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.91# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types92norm_act_kwargs.setdefault('act_layer', act_layer)93if norm_act_kwargs:94norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args95return norm_act_layer96