""" Classifier head and layer factory

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
from functools import partial
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from torch.nn import functional as F

from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .create_act import get_act_layer
from .create_norm import get_norm_layer


def _create_pool(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: Optional[str] = None,
):
    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert num_classes == 0 or use_conv,\
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
    global_pool = SelectAdaptivePool2d(
        pool_type=pool_type,
        flatten=flatten_in_pool,
        input_fmt=input_fmt,
    )
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features
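

# Editor's note, a quick sketch of what `_create_pool` returns (not part of
# the upstream module). `feat_mult()` reflects how the chosen pooling changes
# the channel count, e.g. concatenated avg+max pooling doubles it:
#
#   pool, n = _create_pool(512, num_classes=10, pool_type='avg')
#   # n == 512; pool also flattens, ready for an nn.Linear classifier
#   pool, n = _create_pool(512, num_classes=10, pool_type='catavgmax')
#   # n == 1024; avg- and max-pooled features are concatenated channel-wise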


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias=True)
    return fc
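

# Editor's sketch (illustrative, not upstream code): the two classifier
# variants expect different input layouts. A 1x1 Conv2d classifier keeps
# operating on NCHW feature maps, while the Linear variant needs pooled,
# flattened features:
#
#   conv_fc = _create_fc(512, 10, use_conv=True)   # nn.Conv2d(512, 10, 1)
#   lin_fc = _create_fc(512, 10)                   # nn.Linear(512, 10)
#   conv_fc(torch.randn(2, 512, 7, 7)).shape       # -> (2, 10, 7, 7)
#   lin_fc(torch.randn(2, 512)).shape              # -> (2, 10)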


def create_classifier(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: str = 'NCHW',
        drop_rate: Optional[float] = None,
):
    global_pool, num_pooled_features = _create_pool(
        num_features,
        num_classes,
        pool_type,
        use_conv=use_conv,
        input_fmt=input_fmt,
    )
    fc = _create_fc(
        num_pooled_features,
        num_classes,
        use_conv=use_conv,
    )
    if drop_rate is not None:
        dropout = nn.Dropout(drop_rate)
        return global_pool, dropout, fc
    return global_pool, fc
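

# Editor's usage sketch for the factory (shapes illustrative): wire the
# returned modules into a model head.
#
#   global_pool, fc = create_classifier(2048, num_classes=1000, pool_type='avg')
#   feats = torch.randn(8, 2048, 7, 7)   # backbone output, NCHW
#   logits = fc(global_pool(feats))      # -> (8, 1000)
#
#   # When drop_rate is given, a Dropout module is returned as well:
#   global_pool, dropout, fc = create_classifier(2048, 1000, drop_rate=0.2)
#   logits = fc(dropout(global_pool(feats)))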


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            use_conv: bool = False,
            input_fmt: str = 'NCHW',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            use_conv: Use a 1x1 conv instead of a linear layer as the classifier.
            input_fmt: Input feature tensor layout ('NCHW' or 'NHWC').
        """
        super().__init__()
        self.in_features = in_features
        self.use_conv = use_conv
        self.input_fmt = input_fmt

        global_pool, fc = create_classifier(
            in_features,
            num_classes,
            pool_type,
            use_conv=use_conv,
            input_fmt=input_fmt,
        )
        self.global_pool = global_pool
        self.drop = nn.Dropout(drop_rate)
        self.fc = fc
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes, pool_type=None):
        if pool_type is not None and pool_type != self.global_pool.pool_type:
            # pooling changed, rebuild both the pooling and classifier layers
            self.global_pool, self.fc = create_classifier(
                self.in_features,
                num_classes,
                pool_type=pool_type,
                use_conv=self.use_conv,
                input_fmt=self.input_fmt,
            )
            self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
        else:
            # pooling unchanged, only swap the classifier layer
            num_pooled_features = self.in_features * self.global_pool.feat_mult()
            self.fc = _create_fc(
                num_pooled_features,
                num_classes,
                use_conv=self.use_conv,
            )

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.drop(x)
        if pre_logits:
            return self.flatten(x)
        x = self.fc(x)
        return self.flatten(x)
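

# Editor's usage sketch for ClassifierHead (values illustrative):
#
#   head = ClassifierHead(2048, num_classes=1000, pool_type='avg', drop_rate=0.1)
#   x = torch.randn(4, 2048, 7, 7)
#   head(x).shape                      # -> (4, 1000)
#   head(x, pre_logits=True).shape     # -> (4, 2048), pooled features only
#
#   # Re-purpose the head for a new task without rebuilding the model:
#   head.reset(num_classes=10, pool_type='max')
#   head(x).shape                      # -> (4, 10)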


class NormMlpClassifierHead(nn.Module):
    """Classifier head w/ global pooling, norm, and an optional MLP pre-logits layer."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm2d',
            act_layer: Union[str, Callable] = 'tanh',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        self.use_conv = not pool_type  # stay in NCHW with 1x1 convs when pooling is disabled
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear

        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        self.norm = norm_layer(in_features)
        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', linear_layer(in_features, hidden_size)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes, pool_type=None):
        if pool_type is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        self.use_conv = self.global_pool.is_identity()
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
        if self.hidden_size:
            # convert the pre-logits layer if the conv <-> linear choice changed;
            # a 1x1 Conv2d and a Linear layer share weights up to a reshape
            # between (out, in, 1, 1) and (out, in)
            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
                with torch.no_grad():
                    new_fc = linear_layer(self.in_features, self.hidden_size)
                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
                    new_fc.bias.copy_(self.pre_logits.fc.bias)
                    self.pre_logits.fc = new_fc
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
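

# Editor's usage sketch (illustrative values): the head applies
# pool -> norm -> flatten -> pre-logits MLP -> dropout -> fc.
#
#   head = NormMlpClassifierHead(768, num_classes=1000, hidden_size=1024)
#   x = torch.randn(2, 768, 7, 7)
#   head(x).shape                    # -> (2, 1000)
#   head(x, pre_logits=True).shape   # -> (2, 1024), post-MLP features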