""" Halo Self Attention
2

3
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
4
    - https://arxiv.org/abs/2103.12731
5

6
@misc{2103.12731,
7
Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
8
    Jonathon Shlens},
9
Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
10
Year = {2021},
11
}
12

13
Status:
14
This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me.
15
The attention mechanism works but it's slow as implemented.
16

17
Hacked together by / Copyright 2021 Ross Wightman
18
"""
from typing import List

import torch
from torch import nn
import torch.nn.functional as F

from .helpers import make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert


def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Compute relative logits along one dimension

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: permute output dim according to this
    """
    B, H, W, dim = q.shape
    rel_size = rel_k.shape[0]
    win_size = (rel_size + 1) // 2

    x = (q @ rel_k.transpose(-1, -2))  # B, H, W, rel_size
    x = x.reshape(-1, W, rel_size)

    # pad to shift from relative to absolute indexing
    x_pad = F.pad(x, [0, 1]).flatten(1)
    x_pad = F.pad(x_pad, [0, rel_size - W])

    # reshape and slice out the padded elements
    x_pad = x_pad.reshape(-1, W + 1, rel_size)
    x = x_pad[:, :W, win_size - 1:]  # B * H, W, win_size

    # reshape and tile
    x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
    return x.permute(permute_mask)
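

# Illustrative sanity-check sketch, not part of the original file: demonstrates the expected shapes
# for rel_logits_1d with an 8x8 query block and a 14-wide halo window (block_size=8, halo_size=3).
# The function name below is hypothetical and exists only for documentation purposes.
def _rel_logits_1d_shape_example():
    q = torch.randn(2, 8, 8, 16)            # (batch, block_h, block_w, dim_head)
    rel_k = torch.randn(2 * 14 - 1, 16)     # 2 * win_size - 1 relative positions
    x = rel_logits_1d(q, rel_k, permute_mask=(0, 1, 3, 2, 4))
    # each query position gets a (win_size x win_size) grid of relative logits
    assert x.shape == (2, 8, 8, 14, 14)
    return x.shape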


class PosEmbedRel(nn.Module):
    """ Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    """
    def __init__(self, block_size, win_size, dim_head, scale):
        """
        Args:
            block_size (int): block size
            win_size (int): neighbourhood window size
            dim_head (int): attention head dim
            scale (float): scale factor (for init)
        """
        super().__init__()
        self.block_size = block_size
        self.dim_head = dim_head
        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)

    def forward(self, q):
        # q: (batch * num_heads, num_blocks, block_size ** 2, dim_head)
        B, BB, HW, _ = q.shape

        # relative logits in width dimension.
        q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # relative logits in height dimension.
        q = q.transpose(1, 2)
        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        rel_logits = rel_logits_h + rel_logits_w
        rel_logits = rel_logits.reshape(B, BB, HW, -1)
        return rel_logits
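

# Illustrative shape sketch, not part of the original file: PosEmbedRel maps per-block queries to
# relative position logits over the halo window. Values assume block_size=8, win_size=14,
# dim_head=16; the function name is hypothetical.
def _pos_embed_rel_shape_example():
    pos_embed = PosEmbedRel(block_size=8, win_size=14, dim_head=16, scale=16 ** -0.5)
    q = torch.randn(4, 16, 8 * 8, 16)   # (batch * num_heads, num_blocks, block_size ** 2, dim_head)
    rel_logits = pos_embed(q)
    assert rel_logits.shape == (4, 16, 8 * 8, 14 * 14)   # one logit per (query pos, window pos) pair
    return rel_logits.shape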


class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731

    The internal dimensions of the attention module are controlled by the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to the input dim if not set
      * the value (v) dimension is set to dim_out // num_heads; the v projection determines the output dim
      * the query and key (qk) dimensions are determined by
        * num_heads * dim_head if dim_head is not None
        * num_heads * (dim_out * qk_ratio // num_heads) if dim_head is None
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head is not used

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
        stride (int): output stride of the module, query downscaled if > 1 (default: 1).
        num_heads (int): parallel attention heads (default: 8).
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        block_size (int): size of blocks. (default: 8)
        halo_size (int): size of halo overlap. (default: 3)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
        avg_down (bool): use average pool downsample instead of strided query blocks
        scale_pos_embed (bool): scale the position embedding as well as Q @ K
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
            qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        assert stride in (1, 2)
        self.num_heads = num_heads
        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk ** -0.5
        self.scale_pos_embed = scale_pos_embed
        self.block_size = self.block_size_ds = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.block_stride = 1
        use_avg_pool = False
        if stride > 1:
            use_avg_pool = avg_down or block_size % stride != 0
            self.block_stride = 1 if use_avg_pool else stride
            self.block_size_ds = self.block_size // self.block_stride

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)

        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        std = self.q.weight.shape[1] ** -0.5  # fan-in
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, 'height must be divisible by block_size')
        _assert(W % self.block_size == 0, 'width must be divisible by block_size')
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        # unfold
        q = q.reshape(
            -1, self.dim_head_qk,
            num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
        # B * num_heads, dim_head_qk, block_size_ds, block_size_ds, num_h_blocks, num_w_blocks
        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)
        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
        # FIXME figure out how to switch impl between this and conv2d if XLA being used.
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
        kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v

        if self.scale_pos_embed:
            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
        else:
            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
        # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
        # fold
        out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
        out = out.permute(0, 3, 1, 4, 2).contiguous().view(
            B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
        # B, dim_out, H // block_stride, W // block_stride
        out = self.pool(out)
        return out
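

# A minimal usage sketch, not part of the original file; parameter values are illustrative and the
# function name is hypothetical. Input H and W must be divisible by block_size.
def _halo_attn_usage_example():
    attn = HaloAttn(dim=128, dim_out=128, num_heads=8, block_size=8, halo_size=3)
    x = torch.randn(2, 128, 32, 32)   # B, C, H, W
    out = attn(x)                     # stride=1 keeps the spatial resolution
    assert out.shape == (2, 128, 32, 32)
    return out.shape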


""" Three alternatives for overlapping windows.

`.unfold().unfold()` is the same speed as the stride tricks approach, with similar clarity to F.unfold()

    if is_xla:
        # This code achieves haloing on PyTorch XLA with a reasonable runtime trade-off, but it is
        # EXTREMELY slow for backward on a GPU, so I need a way of selecting based on environment.
        WW = self.win_size ** 2
        pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
        kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
    elif self.stride_tricks:
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
        kv = kv.as_strided((
            B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
            stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
    else:
        kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)

    kv = kv.reshape(
        B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
"""