pytorch-image-models
161 lines · 5.3 KB
1""" Activation Factory
2Hacked together by / Copyright 2020 Ross Wightman
3"""
4from typing import Union, Callable, Type5
6from .activations import *7from .activations_jit import *8from .activations_me import *9from .config import is_exportable, is_scriptable, is_no_jit10
11# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
12# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
13# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
14_has_silu = 'silu' in dir(torch.nn.functional)15_has_hardswish = 'hardswish' in dir(torch.nn.functional)16_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)17_has_mish = 'mish' in dir(torch.nn.functional)18
19
# Name -> activation function (functional form). The native torch.nn.functional
# op is preferred when this PyTorch build provides it; otherwise the custom
# implementation imported from .activations is used.
_ACT_FN_DEFAULT = {
    'silu': F.silu if _has_silu else swish,
    'swish': F.silu if _has_silu else swish,
    'mish': F.mish if _has_mish else mish,
    'relu': F.relu,
    'relu6': F.relu6,
    'leaky_relu': F.leaky_relu,
    'elu': F.elu,
    'celu': F.celu,
    'selu': F.selu,
    'gelu': gelu,
    'gelu_tanh': gelu_tanh,
    'quick_gelu': quick_gelu,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish,
    'hard_mish': hard_mish,
}
# Name -> TorchScript-friendly (jit.script) activation function; native ops
# still take precedence when available.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': F.mish if _has_mish else mish_jit,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_jit,
    'hard_mish': hard_mish_jit,
}
# Name -> memory-efficient activation function (custom autograd, *_me variants);
# native ops still take precedence when available.
_ACT_FN_ME = {
    'silu': F.silu if _has_silu else swish_me,
    'swish': F.silu if _has_silu else swish_me,
    'mish': F.mish if _has_mish else mish_me,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_me,
    'hard_mish': hard_mish_me,
}
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
# Register underscore-free aliases so e.g. 'hardswish' resolves like 'hard_swish'.
for a in _ACT_FNS:
    for _alias, _canonical in (('hardsigmoid', 'hard_sigmoid'), ('hardswish', 'hard_swish')):
        a.setdefault(_alias, a.get(_canonical))
# Name -> activation nn.Module class. The native torch.nn layer is preferred
# when this PyTorch build provides it; otherwise the custom class from
# .activations is used.
_ACT_LAYER_DEFAULT = {
    'silu': nn.SiLU if _has_silu else Swish,
    'swish': nn.SiLU if _has_silu else Swish,
    'mish': nn.Mish if _has_mish else Mish,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'leaky_relu': nn.LeakyReLU,
    'elu': nn.ELU,
    'prelu': PReLU,
    'celu': nn.CELU,
    'selu': nn.SELU,
    'gelu': GELU,
    'gelu_tanh': GELUTanh,
    'quick_gelu': QuickGELU,
    'sigmoid': Sigmoid,
    'tanh': Tanh,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwish,
    'hard_mish': HardMish,
    'identity': nn.Identity,
}
# Name -> TorchScript-friendly (jit.script) activation layer class; native
# layers still take precedence when available.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': nn.Mish if _has_mish else MishJit,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishJit,
    'hard_mish': HardMishJit,
}
# Name -> memory-efficient activation layer class (custom autograd, *Me variants);
# native layers still take precedence when available.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': nn.Mish if _has_mish else MishMe,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishMe,
    'hard_mish': HardMishMe,
}
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
# Register underscore-free aliases so e.g. 'hardswish' resolves like 'hard_swish'.
for a in _ACT_LAYERS:
    for _alias, _canonical in (('hardsigmoid', 'hard_sigmoid'), ('hardswish', 'hard_swish')):
        a.setdefault(_alias, a.get(_canonical))
def get_act_fn(name: Union[Callable, str] = 'relu'):
    """ Activation Function Factory
    Fetching activation fns by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.
    """
    # Falsy name (None / '') means "no activation"; callables pass straight through.
    if not name:
        return None
    if isinstance(name, Callable):
        return name
    # When not exporting or scripting the model, the memory-efficient custom
    # autograd version is the first choice, the jit.script version the second.
    allow_me = not (is_no_jit() or is_exportable() or is_scriptable())
    if allow_me and name in _ACT_FN_ME:
        return _ACT_FN_ME[name]
    allow_jit = not (is_no_jit() or is_exportable())
    if allow_jit and name in _ACT_FN_JIT:
        return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]
def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.
    """
    if name is None:
        return None
    if not isinstance(name, str):
        # Already a layer class / module / callable — pass through untouched.
        return name
    if not name:
        # Empty string means "no activation".
        return None
    # When not exporting or scripting the model, the memory-efficient custom
    # autograd layer is the first choice, the jit.script layer the second.
    allow_me = not (is_no_jit() or is_exportable() or is_scriptable())
    if allow_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]
    allow_jit = not (is_no_jit() or is_exportable())
    if allow_jit and name in _ACT_LAYER_JIT:
        return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
    """ Instantiate an activation layer resolved via get_act_layer().

    ``inplace`` is only forwarded when explicitly given; layer classes that
    don't accept an ``inplace`` keyword are constructed without it.
    """
    layer_cls = get_act_layer(name)
    if layer_cls is None:
        return None
    if inplace is not None:
        try:
            return layer_cls(inplace=inplace, **kwargs)
        except TypeError:
            # Recover if the act layer doesn't have an inplace arg.
            return layer_cls(**kwargs)
    return layer_cls(**kwargs)