pytorch-image-models / lambda_layer.py

""" Lambda Layer

Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
    - https://arxiv.org/abs/2102.08602

@misc{2102.08602,
Author = {Irwan Bello},
Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
Year = {2021},
}

Status:
This impl is a WIP. Code snippets in the paper were used as reference, but
there's a good chance some details are missing or wrong.

I've only implemented the local lambda-conv based pos embeddings.

For a PyTorch impl that includes other embedding options, check out
https://github.com/lucidrains/lambda-networks

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn
import torch.nn.functional as F

from .grid import ndgrid
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_


def rel_pos_indices(size):
    size = to_2tuple(size)
    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
    rel_pos = pos[:, None, :] - pos[:, :, None]
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W
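
# Illustrative example: rel_pos_indices(2) returns a LongTensor of shape (2, 4, 4);
# entry [0, i, j] is the row offset of flattened position j relative to position i,
# shifted by size[0] - 1 so that values lie in [0, 2 * size[0] - 2] and can index
# the (2 * H - 1, 2 * W - 1, dim_qk) pos_emb table used below.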


class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to the input dim if not set
      * the value (v) dimension is set to dim_out // num_heads; the v projection determines the output dim
      * the query (q) and key (k) dimensions are determined by
        * dim_head = (dim_out * qk_ratio // num_heads) if dim_head is None
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head is not set
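
      For example (illustrative): with dim_out=256, num_heads=4, qk_ratio=1.0, and dim_head=None,
      dim_v = 256 // 4 = 64 and dim_qk = make_divisible(256 * 1.0, divisor=8) // 4 = 64,
      so the 1x1 qkv conv produces num_heads * dim_qk + dim_qk + dim_v = 256 + 64 + 64 = 384 channels.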

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature map (H, W) for the relative pos variant
        stride (int): output stride of the module, avg pool used if stride == 2
        num_heads (int): parallel attention heads
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        r (int): local lambda convolution radius. Lambda conv is used if set, relative pos emb if None. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
            qk_ratio=1.0, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, 'dim_out should be divisible by num_heads'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos: Conv3d treats the value dim as a third
            # spatial axis, so the (r, r, 1) kernel spans an r x r neighborhood in (H, W)
            # per value channel, with weights shared across V
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)

    def forward(self, x):
        B, C, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
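        # Equivalent einsum view of the two matmuls above (illustrative):
        #   content_lam = torch.einsum('bkm,bmv->bkv', k, v)
        #   content_out = torch.einsum('bhmk,bkv->bhmv', q, content_lam)
        # All query positions share one content lambda, so no M x M attention map is formed.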

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, K, H, W, V
            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V
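        # Equivalent einsum view (illustrative):
        #   position_out = torch.einsum('bhmk,bmkv->bhmv', q, position_lam.squeeze(1))
        # Unlike the shared content lambda, each position m applies its own positional lambda.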

        out = (content_out + position_out).transpose(-1, -2).reshape(B, -1, H, W)  # B, dim_out (num_heads * V), H, W
        out = self.pool(out)
        return out
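

# Usage sketch (assumes the relative imports above resolve, e.g. when this file
# lives inside the timm package; shapes follow the forward() comments above):
#
#     x = torch.randn(2, 128, 16, 16)
#     lam = LambdaLayer(dim=128, num_heads=4, r=9)  # default lambda-conv pos path
#     lam(x).shape                                  # -> (2, 128, 16, 16)
#
#     lam = LambdaLayer(dim=128, num_heads=4, r=None, feat_size=16)  # rel pos emb path
#     lam(x).shape                                  # -> (2, 128, 16, 16)
#
#     lam = LambdaLayer(dim=128, dim_out=256, num_heads=4, stride=2)  # project + downsample
#     lam(x).shape                                  # -> (2, 256, 8, 8)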