1""" Lambda Layer
2
3Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
4- https://arxiv.org/abs/2102.08602
5
6@misc{2102.08602,
7Author = {Irwan Bello},
8Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
9Year = {2021},
10}
11
12Status:
13This impl is a WIP. Code snippets in the paper were used as reference but
14good chance some details are missing/wrong.
15
16I've only implemented local lambda conv based pos embeddings.
17
18For a PyTorch impl that includes other embedding options checkout
19https://github.com/lucidrains/lambda-networks
20
21Hacked together by / Copyright 2021 Ross Wightman
22"""
import torch
from torch import nn
import torch.nn.functional as F

from .grid import ndgrid
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_


def rel_pos_indices(size):
    size = to_2tuple(size)
    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
    rel_pos = pos[:, None, :] - pos[:, :, None]
    # shift offsets into [0, 2 * size - 2] so they can index the (2H-1, 2W-1) pos_emb table
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W
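# Worked example (illustrative, not part of the original file): for size=2 the
# flattened grid has M = 4 positions, so rel_pos_indices(2) returns a (2, 4, 4)
# tensor of offsets shifted into [0, 2] that index a (3, 3, dim_qk) pos_emb table:
#
#   idx = rel_pos_indices(2)                # torch.Size([2, 4, 4])
#   int(idx[0].min()), int(idx[0].max())    # 0, 2

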
class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to the input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimensions are determined by
        * dim_head = (dim_out * qk_ratio // num_heads) if dim_head is None
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head is not set
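
    For example (illustrative numbers, not from the paper): with dim=dim_out=256,
    num_heads=4, qk_ratio=1.0, and dim_head left unset, dim_v = 256 // 4 = 64 and
    dim_qk = make_divisible(256 * 1.0, divisor=8) // 4 = 64, so the qkv projection
    produces num_heads * 64 + 64 + 64 = 384 output channels.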

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature map (H, W) for the relative pos variant
        stride (int): output stride of the module, avg pool used if stride == 2
        num_heads (int): parallel attention heads
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos embedding. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
            qk_ratio=1.0, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, 'dim_out must be divisible by num_heads'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)

    def forward(self, x):
        B, C, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, K, H, W, V
            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        # reshape with -1 rather than the input C so dim_out != dim still works
        out = (content_out + position_out).transpose(-1, -2).reshape(B, -1, H, W)  # B, dim_out (num_heads * V), H, W
        out = self.pool(out)
        return out
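
# Minimal smoke test (an illustrative sketch, not part of the original module);
# because of the relative imports above it must be run as a module, e.g. with
# ``python -m`` from the package root. With stride=1 the layer is
# shape-preserving; stride=2 halves H and W via average pooling.
if __name__ == '__main__':
    _layer = LambdaLayer(dim=128, dim_out=128, num_heads=4, dim_head=16, r=9)
    _x = torch.randn(2, 128, 16, 16)
    _out = _layer(_x)
    assert _out.shape == (2, 128, 16, 16)  # B, dim_out, H, W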