1""" Classifier head and layer factory
2
3Hacked together by / Copyright 2020 Ross Wightman
4"""
5from collections import OrderedDict6from functools import partial7from typing import Optional, Union, Callable8
9import torch10import torch.nn as nn11from torch.nn import functional as F12
13from .adaptive_avgmax_pool import SelectAdaptivePool2d14from .create_act import get_act_layer15from .create_norm import get_norm_layer16
17
18def _create_pool(19num_features: int,20num_classes: int,21pool_type: str = 'avg',22use_conv: bool = False,23input_fmt: Optional[str] = None,24):25flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling26if not pool_type:27assert num_classes == 0 or use_conv,\28'Pooling can only be disabled if classifier is also removed or conv classifier is used'29flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)30global_pool = SelectAdaptivePool2d(31pool_type=pool_type,32flatten=flatten_in_pool,33input_fmt=input_fmt,34)35num_pooled_features = num_features * global_pool.feat_mult()36return global_pool, num_pooled_features37
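
# Note (an illustrative sketch, not executed here): concatenating pool modes multiply the
# feature width via feat_mult(), e.g. 'catavgmax' concatenates avg- and max-pooled features:
#
#   pool, n = _create_pool(num_features=512, num_classes=10, pool_type='catavgmax')
#   # n == 1024  (feat_mult() == 2)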


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias=True)
    return fc


def create_classifier(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: str = 'NCHW',
        drop_rate: Optional[float] = None,
):
    global_pool, num_pooled_features = _create_pool(
        num_features,
        num_classes,
        pool_type,
        use_conv=use_conv,
        input_fmt=input_fmt,
    )
    fc = _create_fc(
        num_pooled_features,
        num_classes,
        use_conv=use_conv,
    )
    if drop_rate is not None:
        dropout = nn.Dropout(drop_rate)
        return global_pool, dropout, fc
    return global_pool, fc
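
# Example usage of create_classifier (a sketch; shapes are illustrative assumptions):
#
#   pool, fc = create_classifier(num_features=2048, num_classes=1000)
#   feats = torch.randn(2, 2048, 7, 7)   # NCHW backbone features
#   logits = fc(pool(feats))             # pool flattens for nn.Linear -> (2, 1000)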


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            use_conv: bool = False,
            input_fmt: str = 'NCHW',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            use_conv: Use a 1x1 convolution as the classifier instead of a linear layer.
            input_fmt: Input feature format ('NCHW' or 'NHWC').
        """
        super(ClassifierHead, self).__init__()
        self.in_features = in_features
        self.use_conv = use_conv
        self.input_fmt = input_fmt

        global_pool, fc = create_classifier(
            in_features,
            num_classes,
            pool_type,
            use_conv=use_conv,
            input_fmt=input_fmt,
        )
        self.global_pool = global_pool
        self.drop = nn.Dropout(drop_rate)
        self.fc = fc
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes, pool_type=None):
        if pool_type is not None and pool_type != self.global_pool.pool_type:
            self.global_pool, self.fc = create_classifier(
                self.in_features,
                num_classes,
                pool_type=pool_type,
                use_conv=self.use_conv,
                input_fmt=self.input_fmt,
            )
            self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
        else:
            num_pooled_features = self.in_features * self.global_pool.feat_mult()
            self.fc = _create_fc(
                num_pooled_features,
                num_classes,
                use_conv=self.use_conv,
            )

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.drop(x)
        if pre_logits:
            return self.flatten(x)
        x = self.fc(x)
        return self.flatten(x)
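
# Example usage of ClassifierHead (a sketch; shapes are illustrative assumptions):
#
#   head = ClassifierHead(in_features=768, num_classes=10, drop_rate=0.1)
#   x = torch.randn(4, 768, 7, 7)        # NCHW features
#   head(x).shape                        # torch.Size([4, 10])
#   head(x, pre_logits=True).shape       # torch.Size([4, 768]) pooled features
#   head.reset(num_classes=100)          # swap in a new 100-class classifier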


class NormMlpClassifierHead(nn.Module):
    """Classifier head w/ a norm layer and optional MLP (pre-logits) ahead of the classifier."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm2d',
            act_layer: Union[str, Callable] = 'tanh',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        self.use_conv = not pool_type
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear

        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        self.norm = norm_layer(in_features)
        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', linear_layer(in_features, hidden_size)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes, pool_type=None):
        if pool_type is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        self.use_conv = self.global_pool.is_identity()
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
        if self.hidden_size:
            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
                # pooling change altered the expected layer type; convert the pre-logits fc
                # between 1x1 Conv2d and Linear, preserving the trained weights
                with torch.no_grad():
                    new_fc = linear_layer(self.in_features, self.hidden_size)
                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
                    new_fc.bias.copy_(self.pre_logits.fc.bias)
                    self.pre_logits.fc = new_fc
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
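
# Example usage of NormMlpClassifierHead (a sketch; shapes are illustrative assumptions):
#
#   head = NormMlpClassifierHead(in_features=768, num_classes=1000, hidden_size=512)
#   x = torch.randn(2, 768, 7, 7)
#   head(x).shape                        # torch.Size([2, 1000])
#   head(x, pre_logits=True).shape       # torch.Size([2, 512]) MLP output
#
# With pool_type='' the head stays fully convolutional (1x1 convs replace linear layers),
# and reset() converts the pre-logits weights between the two layouts when pooling changes.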