pytorch-image-models / attention_pool2d.py
""" Attention Pool 2D

Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.

Based on idea in CLIP by OpenAI, licensed Apache 2.0
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Union, Tuple

import torch
import torch.nn as nn

from .helpers import to_2tuple
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_


class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ a rotary embedding instead of a learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at resolutions differing from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # only init the bias when it exists (qkv_bias=False leaves it as None)
            nn.init.zeros_(self.qkv.bias)
    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        # flatten the NCHW feature map into a sequence of N spatial tokens: (B, N, C)
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        # prepend the mean token; it serves as the pooling query
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # apply rotary embedding to the spatial tokens only, not the mean/query token
        qc, q = q[:, :, :1], q[:, :, 1:]
        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # the mean/query token's output is the pooled feature
        return x[:, 0]
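# Example usage (a minimal sketch, not part of the original file; the channel
# count, batch size, and 7x7 grid below are illustrative assumptions):
#
#   pool = RotAttentionPool2d(in_features=2048, num_heads=8)
#   feats = torch.randn(2, 2048, 7, 7)   # NCHW feature map, e.g. CNN output
#   pooled = pool(feats)                 # -> (2, 2048), one pooled vector per image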


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Based on the impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires a feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # learned absolute position embedding, one per spatial location plus one for the query token
        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # only init the bias when it exists (qkv_bias=False leaves it as None)
            nn.init.zeros_(self.qkv.bias)
    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        # input spatial size must match the size given at construction
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        # flatten the NCHW feature map into a sequence of N spatial tokens: (B, N, C)
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        # prepend the mean token as the pooling query, then add the position embedding
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # the mean/query token's output is the pooled feature
        return x[:, 0]
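# Example usage (a minimal sketch, not part of the original file; feat_size and
# the other numbers are illustrative assumptions; the input H, W must equal feat_size):
#
#   pool = AttentionPool2d(in_features=2048, feat_size=7, num_heads=8)
#   feats = torch.randn(2, 2048, 7, 7)   # NCHW feature map, e.g. CNN output
#   pooled = pool(feats)                 # -> (2, 2048), one pooled vector per image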