1""" Attention Factory
2
3Hacked together by / Copyright 2021 Ross Wightman
4"""
5import torch6from functools import partial7
8from .bottleneck_attn import BottleneckAttn9from .cbam import CbamModule, LightCbamModule10from .eca import EcaModule, CecaModule11from .gather_excite import GatherExcite12from .global_context import GlobalContext13from .halo_attn import HaloAttn14from .lambda_layer import LambdaLayer15from .non_local_attn import NonLocalAttn, BatNonLocalAttn16from .selective_kernel import SelectiveKernel17from .split_attn import SplitAttn18from .squeeze_excite import SEModule, EffectiveSEModule19


def get_attn(attn_type):
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    module_cls = None
    if attn_type:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
            # Lightweight attention modules (channel and/or coarse spatial).
            # Typically added to existing network architecture blocks in addition to existing convolutions.
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
            elif attn_type == 'ecam':
                module_cls = partial(EcaModule, use_mlp=True)
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'ge':
                module_cls = GatherExcite
            elif attn_type == 'gc':
                module_cls = GlobalContext
            elif attn_type == 'gca':
                module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
            elif attn_type == 'cbam':
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule

            # Attention / attention-like modules w/ significant params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'sk':
                module_cls = SelectiveKernel
            elif attn_type == 'splat':
                module_cls = SplitAttn

            # Self-attention / attention-like modules w/ significant compute and/or params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'lambda':
                return LambdaLayer
            elif attn_type == 'bottleneck':
                return BottleneckAttn
            elif attn_type == 'halo':
                return HaloAttn
            elif attn_type == 'nl':
                module_cls = NonLocalAttn
            elif attn_type == 'bat':
                module_cls = BatNonLocalAttn

            # Woops!
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
            if attn_type:
                module_cls = SEModule
        else:
            module_cls = attn_type
    return module_cls
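

# Dispatch summary for get_attn() (editorial note, not library documentation):
# a string maps to a layer class or functools.partial, a bool selects the
# default SE layer, an nn.Module instance passes through unchanged, and falsy
# values yield None. For example:
#
#   get_attn('eca')   # -> EcaModule (class)
#   get_attn('ecam')  # -> partial(EcaModule, use_mlp=True)
#   get_attn(True)    # -> SEModule
#   get_attn(None)    # -> None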


def create_attn(attn_type, channels, **kwargs):
    module_cls = get_attn(attn_type)
    if module_cls is not None:
        # NOTE: it's expected the first (positional) argument of all attention layers is the number of input channels
        return module_cls(channels, **kwargs)
    return None
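

# ---------------------------------------------------------------------------
# Usage sketch (editorial example, not part of the library). Because this
# module uses relative imports, run it in package context, e.g. via
# `python -m timm.models.layers.create_attn` (the exact package path may
# differ across timm versions).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    x = torch.randn(2, 64, 56, 56)

    # String lookup: 'se' -> SEModule, instantiated with the channel count as
    # the first positional argument; channel attention preserves input shape.
    attn = create_attn('se', 64)
    assert isinstance(attn, SEModule)
    assert attn(x).shape == x.shape

    # Bool True selects the default SE layer; None produces no layer at all.
    assert isinstance(create_attn(True, 64), SEModule)
    assert create_attn(None, 64) is None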