pytorch-image-models
233 строки · 10.4 Кб
""" Halo Self Attention

Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
    - https://arxiv.org/abs/2103.12731

@misc{2103.12731,
Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
    Jonathon Shlens},
Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
Year = {2021},
}

Status:
This impl is a WIP; there is no official reference impl and some details in the paper weren't clear to me.
The attention mechanism works, but it is slow as implemented.

Hacked together by / Copyright 2021 Ross Wightman
"""
19from typing import List20
21import torch22from torch import nn23import torch.nn.functional as F24
25from .helpers import make_divisible26from .weight_init import trunc_normal_27from .trace_utils import _assert28
29
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Compute relative position logits along a single spatial dimension.

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Every query row is dotted against every relative embedding. Then a "pad, flatten, re-view"
    trick shifts each successive query row by one column. This converts relative-offset indexing
    into absolute key-position indexing without an explicit gather. Concretely, before permuting:
        out[b, h, t, w, k] == q[b, h, w, :] @ rel_k[k + window - 1 - w, :]   (for every t)

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: permute output dim according to this

    Returns:
        Tensor of shape (batch, height, window, width, window) with its dims reordered by
        `permute_mask`. The third dim is an `expand`-ed (broadcast) copy, not materialized.
    """
    batch, height, width, _ = q.shape
    num_rel = rel_k.shape[0]
    window = (num_rel + 1) // 2

    # logits against every relative offset, folded to (batch * height, width, num_rel)
    logits = torch.matmul(q, rel_k.transpose(-1, -2)).reshape(-1, width, num_rel)

    # Append one zero column per row, flatten, then zero-pad the tail so the buffer re-views as
    # (width + 1) rows of num_rel. Each row is now shifted one column relative to the previous one.
    shifted = F.pad(logits, [0, 1]).flatten(1)
    shifted = F.pad(shifted, [0, num_rel - width]).reshape(-1, width + 1, num_rel)

    # drop the spill-over row and the leading columns that fall outside the window
    absolute = shifted[:, :width, window - 1:]

    # broadcast across the other spatial axis of the window, then reorder dims for the caller
    tiled = absolute.reshape(batch, height, 1, width, window).expand(-1, -1, window, -1, -1)
    return tiled.permute(permute_mask)
60
class PosEmbedRel(nn.Module):
    """ Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Keeps separate learnable relative embeddings for the height and width axes, each with
    (2 * win_size - 1) rows of size dim_head. It turns blocks of queries into additive attention
    logits over a key window.
    """
    def __init__(self, block_size, win_size, dim_head, scale):
        """
        Args:
            block_size (int): block size
            win_size (int): neighbourhood window size
            dim_head (int): attention head dim
            scale (float): scale factor (for init)
        """
        super().__init__()
        self.block_size = block_size
        self.dim_head = dim_head
        num_rel = win_size * 2 - 1
        # height table is created first, then width (keeps the same RNG draw order)
        self.height_rel = nn.Parameter(torch.randn(num_rel, dim_head) * scale)
        self.width_rel = nn.Parameter(torch.randn(num_rel, dim_head) * scale)

    def forward(self, q):
        """
        Args:
            q: 4D query tensor (B, BB, HW, dim_head); presumably HW == block_size ** 2 and
                BB is the number of blocks -- confirm against the caller.

        Returns:
            Relative position logits reshaped to (B, BB, HW, -1).
        """
        B, BB, HW, _ = q.shape
        blocks = q.reshape(-1, self.block_size, self.block_size, self.dim_head)

        # width-axis logits, queries laid out as (N, row, col, dim)
        logits_w = rel_logits_1d(blocks, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # height-axis logits: swap row/col so the height axis becomes the scanned one
        logits_h = rel_logits_1d(blocks.transpose(1, 2), self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        return (logits_h + logits_w).reshape(B, BB, HW, -1)
96
class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731

    The internal dimensions of the attention module are controlled by the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query and key (qk) dimensions are determined by
        * num_heads * dim_head if dim_head is not None
        * num_heads * (make_divisible(dim_out * qk_ratio, divisor=8) // num_heads) if dim_head is None
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
        stride: output stride of the module, query downscaled if > 1 (default: 1). Only 1 or 2 are accepted.
        num_heads: parallel attention heads (default: 8).
        dim_head: dimension of query and key heads,
            calculated from make_divisible(dim_out * qk_ratio, divisor=8) // num_heads if not set
        block_size (int): size of blocks. (default: 8)
        halo_size (int): size of halo overlap. (default: 3)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool) : add bias to q, k, and v projections
        avg_down (bool): use average pool downsample instead of strided query blocks
        scale_pos_embed (bool): scale the position embedding as well as Q @ K
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
            qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, f'dim_out ({dim_out}) must be divisible by num_heads ({num_heads})'
        assert stride in (1, 2), f'stride must be 1 or 2, got {stride}'
        self.num_heads = num_heads
        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk ** -0.5
        self.scale_pos_embed = scale_pos_embed
        self.block_size = self.block_size_ds = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.block_stride = 1
        use_avg_pool = False
        if stride > 1:
            # strided query blocks only work when the stride evenly divides the block, else fall back to avg pool
            use_avg_pool = avg_down or block_size % stride != 0
            self.block_stride = 1 if use_avg_pool else stride
            self.block_size_ds = self.block_size // self.block_stride

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)

        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize q/kv projection weights (truncated normal, fan-in std) and the rel-pos tables."""
        std = self.q.weight.shape[1] ** -0.5  # fan-in
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        """
        Args:
            x: (B, dim, H, W) feature map; H and W must be multiples of block_size.

        Returns:
            (B, dim_out, H // stride, W // stride) feature map.
        """
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, 'input height must be divisible by block_size')
        _assert(W % self.block_size == 0, 'input width must be divisible by block_size')
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        # unfold
        q = q.reshape(
            -1, self.dim_head_qk,
            num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
        # B * num_heads, dim_head_qk, block_size_ds, block_size_ds, num_h_blocks, num_w_blocks
        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size_ds ** 2, dim_head_qk

        kv = self.kv(x)
        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
        # FIXME figure out how to switch impl between this and conv2d if XLA being used.
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
        kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v

        if self.scale_pos_embed:
            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
        else:
            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
        # B * num_heads, num_blocks, block_size_ds ** 2, win_size ** 2
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size_ds ** 2, num_blocks
        # fold
        out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
        out = out.permute(0, 3, 1, 4, 2).contiguous().view(
            B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
        # B, dim_out, H // block_stride, W // block_stride
        out = self.pool(out)
        return out
212
""" Three alternatives for overlapping windows.

`.unfold().unfold()` is the same speed as the stride-tricks approach, with clarity similar to `F.unfold()`.

if is_xla:
    # This code achieves haloing on PyTorch XLA with a reasonable runtime trade-off, but it is
    # EXTREMELY slow for backward on a GPU, so a way of selecting the impl based on environment is needed.
    WW = self.win_size ** 2
    pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
    kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
elif self.stride_tricks:
    kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
    kv = kv.as_strided((
        B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
        stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
else:
    kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)

kv = kv.reshape(
    B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
"""
234