pytorch-image-models

Форк
0
/
create_norm_act.py 
95 строк · 3.7 Кб
1
""" NormAct (Normalizaiton + Activation Layer) Factory
2

3
Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
4
isntances in models. Where these are used it will be possible to swap separate BN + act layers with
5
combined modules like IABN or EvoNorms.
6

7
Hacked together by / Copyright 2020 Ross Wightman
8
"""
9
import types
10
import functools
11

12
from .evo_norm import *
13
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
14
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
15
from .inplace_abn import InplaceAbn
16

17
_NORM_ACT_MAP = dict(
18
    batchnorm=BatchNormAct2d,
19
    batchnorm2d=BatchNormAct2d,
20
    groupnorm=GroupNormAct,
21
    groupnorm1=functools.partial(GroupNormAct, num_groups=1),
22
    layernorm=LayerNormAct,
23
    layernorm2d=LayerNormAct2d,
24
    evonormb0=EvoNorm2dB0,
25
    evonormb1=EvoNorm2dB1,
26
    evonormb2=EvoNorm2dB2,
27
    evonorms0=EvoNorm2dS0,
28
    evonorms0a=EvoNorm2dS0a,
29
    evonorms1=EvoNorm2dS1,
30
    evonorms1a=EvoNorm2dS1a,
31
    evonorms2=EvoNorm2dS2,
32
    evonorms2a=EvoNorm2dS2a,
33
    frn=FilterResponseNormAct2d,
34
    frntlu=FilterResponseNormTlu2d,
35
    inplaceabn=InplaceAbn,
36
    iabn=InplaceAbn,
37
)
38
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
39
# has act_layer arg to define act type
40
_NORM_ACT_REQUIRES_ARG = {
41
    BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
42

43

44
def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
45
    layer = get_norm_act_layer(layer_name, act_layer=act_layer)
46
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
47
    if jit:
48
        layer_instance = torch.jit.script(layer_instance)
49
    return layer_instance
50

51

52
def get_norm_act_layer(norm_layer, act_layer=None):
53
    if norm_layer is None:
54
        return None
55
    assert isinstance(norm_layer, (type, str,  types.FunctionType, functools.partial))
56
    assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
57
    norm_act_kwargs = {}
58

59
    # unbind partial fn, so args can be rebound later
60
    if isinstance(norm_layer, functools.partial):
61
        norm_act_kwargs.update(norm_layer.keywords)
62
        norm_layer = norm_layer.func
63

64
    if isinstance(norm_layer, str):
65
        if not norm_layer:
66
            return None
67
        layer_name = norm_layer.replace('_', '').lower().split('-')[0]
68
        norm_act_layer = _NORM_ACT_MAP[layer_name]
69
    elif norm_layer in _NORM_ACT_TYPES:
70
        norm_act_layer = norm_layer
71
    elif isinstance(norm_layer,  types.FunctionType):
72
        # if function type, must be a lambda/fn that creates a norm_act layer
73
        norm_act_layer = norm_layer
74
    else:
75
        type_name = norm_layer.__name__.lower()
76
        if type_name.startswith('batchnorm'):
77
            norm_act_layer = BatchNormAct2d
78
        elif type_name.startswith('groupnorm'):
79
            norm_act_layer = GroupNormAct
80
        elif type_name.startswith('groupnorm1'):
81
            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
82
        elif type_name.startswith('layernorm2d'):
83
            norm_act_layer = LayerNormAct2d
84
        elif type_name.startswith('layernorm'):
85
            norm_act_layer = LayerNormAct
86
        else:
87
            assert False, f"No equivalent norm_act layer for {type_name}"
88

89
    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
90
        # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
91
        # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
92
        norm_act_kwargs.setdefault('act_layer', act_layer)
93
    if norm_act_kwargs:
94
        norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)  # bind/rebind args
95
    return norm_act_layer
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.