pytorch-image-models

""" Attention Factory

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from functools import partial

from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule


def get_attn(attn_type):
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    module_cls = None
    if attn_type:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
            # Lightweight attention modules (channel and/or coarse spatial).
            # Typically added to existing network architecture blocks in addition to existing convolutions.
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
            elif attn_type == 'ecam':
                module_cls = partial(EcaModule, use_mlp=True)
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'ge':
                module_cls = GatherExcite
            elif attn_type == 'gc':
                module_cls = GlobalContext
            elif attn_type == 'gca':
                module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
            elif attn_type == 'cbam':
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule

            # Attention / attention-like modules w/ significant params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'sk':
                module_cls = SelectiveKernel
            elif attn_type == 'splat':
                module_cls = SplitAttn

            # Self-attention / attention-like modules w/ significant compute and/or params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'lambda':
                return LambdaLayer
            elif attn_type == 'bottleneck':
                return BottleneckAttn
            elif attn_type == 'halo':
                return HaloAttn
            elif attn_type == 'nl':
                module_cls = NonLocalAttn
            elif attn_type == 'bat':
                module_cls = BatNonLocalAttn

            # Woops!
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
            if attn_type:
                module_cls = SEModule
        else:
            module_cls = attn_type
    return module_cls


def create_attn(attn_type, channels, **kwargs):
    module_cls = get_attn(attn_type)
    if module_cls is not None:
        # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
        return module_cls(channels, **kwargs)
    return None
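
Usage sketch (not part of the file above; it assumes the factory is importable as in timm, e.g. "from timm.models.layers import create_attn, get_attn" — the exact import path varies across timm versions):

import torch
from timm.models.layers import create_attn, get_attn

# create_attn() instantiates a module; the first positional argument is the
# number of input channels, and extra kwargs are forwarded to the constructor.
se = create_attn('se', 64)          # SEModule over 64 channels
x = torch.randn(2, 64, 32, 32)
y = se(x)                           # gated output, same shape as x

# get_attn() returns the class (or a functools.partial) rather than an
# instance, so constructor arguments such as stride can be bound by the caller.
halo_cls = get_attn('halo')         # HaloAttn class, not an instance
eca = get_attn('eca')(64)           # equivalent to create_attn('eca', 64)

# A falsy attn_type yields None, letting callers make attention optional.
assert create_attn(None, 64) is None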
