# pytorch-image-models
1""" Attention Pool 2D
2
3Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
4
5Based on idea in CLIP by OpenAI, licensed Apache 2.0
6https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
7
8Hacked together by / Copyright 2021 Ross Wightman
9"""
10from typing import Union, Tuple11
12import torch13import torch.nn as nn14
15from .helpers import to_2tuple16from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding17from .weight_init import trunc_normal_18
19
class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        """
        Args:
            in_features: Number of input channels (C of the NCHW feature map).
            out_features: Output channels of the final projection. Defaults to in_features.
            embed_dim: Internal attention embedding dim. Defaults to in_features.
            num_heads: Number of attention heads; must evenly divide embed_dim.
            qkv_bias: If True, add a learnable bias to the fused qkv projection.
        """
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # FIX: with qkv_bias=False, nn.Linear sets .bias = None and
            # nn.init.zeros_(None) would raise; only zero the bias when it exists.
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        """Pool an NCHW feature map to (B, out_features) via attention.

        A mean-pooled token is prepended as the query target; its attention
        output is returned as the pooled feature.
        """
        B, _, H, W = x.shape
        N = H * W
        # NCHW -> (B, N, C) token sequence
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        # prepend mean token; it acts like a class token that queries the rest
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # rotary pos embed applies only to the N spatial tokens, not the pooled token
        qc, q = q[:, :, :1], q[:, :, 1:]
        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # return only the pooled (first) token
        return x[:, 0]
78
class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        """
        Args:
            in_features: Number of input channels (C of the NCHW feature map).
            feat_size: Fixed spatial size (H, W) of the input feature map; an int is
                expanded to a square size via to_2tuple.
            out_features: Output channels of the final projection. Defaults to in_features.
            embed_dim: Internal attention embedding dim. Defaults to in_features.
            num_heads: Number of attention heads; must evenly divide embed_dim.
            qkv_bias: If True, add a learnable bias to the fused qkv projection.
        """
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # learned absolute pos embed for N spatial tokens + 1 pooled token
        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # FIX: with qkv_bias=False, nn.Linear sets .bias = None and
            # nn.init.zeros_(None) would raise; only zero the bias when it exists.
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        """Pool an NCHW feature map to (B, out_features) via attention.

        Input spatial size must match the feat_size given at construction.
        """
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        # NCHW -> (B, N, C) token sequence
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        # prepend mean token as the pooling query, then add learned pos embed
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # return only the pooled (first) token
        return x[:, 0]