pytorch-image-models

""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Union, Callable, Type

from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit

# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
# It also has hardsigmoid, hardswish, and soon mish. This code will use the native versions if present.
# Eventually, the custom SiLU, Mish, and Hard* layers will be removed and only the native variants will be used.
_has_silu = 'silu' in dir(torch.nn.functional)
_has_hardswish = 'hardswish' in dir(torch.nn.functional)
_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
_has_mish = 'mish' in dir(torch.nn.functional)


_ACT_FN_DEFAULT = dict(
    silu=F.silu if _has_silu else swish,
    swish=F.silu if _has_silu else swish,
    mish=F.mish if _has_mish else mish,
    relu=F.relu,
    relu6=F.relu6,
    leaky_relu=F.leaky_relu,
    elu=F.elu,
    celu=F.celu,
    selu=F.selu,
    gelu=gelu,
    gelu_tanh=gelu_tanh,
    quick_gelu=quick_gelu,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
    hard_swish=F.hardswish if _has_hardswish else hard_swish,
    hard_mish=hard_mish,
)

_ACT_FN_JIT = dict(
    silu=F.silu if _has_silu else swish_jit,
    swish=F.silu if _has_silu else swish_jit,
    mish=F.mish if _has_mish else mish_jit,
    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
    hard_mish=hard_mish_jit,
)

_ACT_FN_ME = dict(
    silu=F.silu if _has_silu else swish_me,
    swish=F.silu if _has_silu else swish_me,
    mish=F.mish if _has_mish else mish_me,
    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
    hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
    hard_mish=hard_mish_me,
)

_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
    a.setdefault('hardswish', a.get('hard_swish'))


_ACT_LAYER_DEFAULT = dict(
    silu=nn.SiLU if _has_silu else Swish,
    swish=nn.SiLU if _has_silu else Swish,
    mish=nn.Mish if _has_mish else Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    leaky_relu=nn.LeakyReLU,
    elu=nn.ELU,
    prelu=PReLU,
    celu=nn.CELU,
    selu=nn.SELU,
    gelu=GELU,
    gelu_tanh=GELUTanh,
    quick_gelu=QuickGELU,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
    hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
    hard_mish=HardMish,
    identity=nn.Identity,
)

_ACT_LAYER_JIT = dict(
    silu=nn.SiLU if _has_silu else SwishJit,
    swish=nn.SiLU if _has_silu else SwishJit,
    mish=nn.Mish if _has_mish else MishJit,
    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
    hard_mish=HardMishJit,
)

_ACT_LAYER_ME = dict(
    silu=nn.SiLU if _has_silu else SwishMe,
    swish=nn.SiLU if _has_silu else SwishMe,
    mish=nn.Mish if _has_mish else MishMe,
    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
    hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
    hard_mish=HardMishMe,
)

_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
    a.setdefault('hardswish', a.get('hard_swish'))


def get_act_fn(name: Union[Callable, str] = 'relu'):
    """ Activation Function Factory
    Fetching activation functions by name with this factory allows export- or TorchScript-friendly
    versions to be returned dynamically, based on the current config.
    """
    if not name:
        return None
    if isinstance(name, Callable):
        return name
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        # If not exporting or scripting the model, first look for a memory-efficient version
        # with custom autograd, then fall back.
        if name in _ACT_FN_ME:
            return _ACT_FN_ME[name]
    if not (is_no_jit() or is_exportable()):
        if name in _ACT_FN_JIT:
            return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]

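
# Usage sketch (illustrative, not part of the original file; the helper name is
# hypothetical): resolving an activation function by name. With the default config on
# PyTorch >= 1.7, 'swish'/'silu' should resolve to the native F.silu; older versions
# fall back to the custom implementations imported above.
def _example_resolve_act_fn():
    act_fn = get_act_fn('swish')   # plain callable, e.g. F.silu or swish
    x = torch.randn(2, 8)
    return act_fn(x)
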

def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this factory allows export- or TorchScript-friendly
    versions to be returned dynamically, based on the current config.
    """
    if name is None:
        return None
    if not isinstance(name, str):
        # already a callable, module type, etc.
        return name
    if not name:
        return None
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        if name in _ACT_LAYER_ME:
            return _ACT_LAYER_ME[name]
    if not (is_no_jit() or is_exportable()):
        if name in _ACT_LAYER_JIT:
            return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]

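
# Usage sketch (illustrative, not part of the original file; the helper name is
# hypothetical): get_act_layer returns a class rather than an instance, so the result
# can be passed around as an `act_layer` argument and instantiated later.
def _example_resolve_act_layer():
    act_layer = get_act_layer('hard_swish')   # e.g. nn.Hardswish on recent PyTorch
    act = act_layer(inplace=True)             # instantiate like any nn.Module subclass
    return act(torch.randn(2, 8))
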

def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
    """ Create an activation layer instance, passing `inplace` to the constructor
    only when it is given and accepted.
    """
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
    if inplace is None:
        return act_layer(**kwargs)
    try:
        return act_layer(inplace=inplace, **kwargs)
    except TypeError:
        # recover gracefully if the activation layer does not accept an `inplace` arg
        return act_layer(**kwargs)
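

# Usage sketch (illustrative, not part of the original file; the helper name is
# hypothetical): create_act_layer builds an instance directly, forwarding `inplace`
# only when it is explicitly given and silently dropping it for constructors that do
# not accept it.
def _example_create_act_layer():
    relu = create_act_layer('relu', inplace=True)   # nn.ReLU(inplace=True)
    gelu = create_act_layer('gelu')                 # no inplace passed at all
    none = create_act_layer(None)                   # None/empty names return None
    return relu, gelu, none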
