# pytorch-image-models
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn

def add_ml_decoder_head(model):
    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # most CNN models, like ResNet50
        model.global_pool = nn.Identity()
        del model.fc
        num_classes = model.num_classes
        num_features = model.num_features
        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'):  # EfficientNet
        model.global_pool = nn.Identity()
        del model.classifier
        num_classes = model.num_classes
        num_features = model.num_features
        model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():  # hasattr(model, 'head')
        del model.head
        num_classes = model.num_classes
        num_features = model.num_features
        model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    else:
        print("Model is not currently supported by ml-decoder")
        exit(-1)
    if hasattr(model, 'drop_rate'):  # ML-Decoder has its own inner dropout
        model.drop_rate = 0
    return model
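
# A minimal usage sketch (illustrative, not part of the original file): convert a
# timm CNN to a multi-label ML-Decoder head. Assumes timm is installed; 'resnet50'
# takes the global_pool/fc branch above, and num_classes=80 is an arbitrary example.
#
#   import timm
#   model = timm.create_model('resnet50', pretrained=True, num_classes=80)
#   model = add_ml_decoder_head(model)  # model.fc is now an MLDecoder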

class TransformerDecoderLayerOptimal(nn.Module):
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super(TransformerDecoderLayerOptimal, self).__init__()
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = torch.nn.functional.relu
        super(TransformerDecoderLayerOptimal, self).__setstate__(state)
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                **kwargs) -> Tensor:
        # **kwargs absorbs extra arguments (e.g. tgt_is_causal) that newer
        # versions of nn.TransformerDecoder pass down to their layers.
        # Note: there is no self-attention over the queries here; the first
        # residual branch is just dropout on tgt, and the queries only
        # cross-attend to the image tokens below.
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
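
# A quick shape check (illustrative sketch, not part of the original file): the layer
# takes tgt of shape [num_queries, batch, d_model] and memory of shape
# [num_tokens, batch, d_model], and returns a tensor shaped like tgt.
#
#   layer = TransformerDecoderLayerOptimal(d_model=768)
#   tgt = torch.zeros(100, 2, 768)     # 100 label queries, batch of 2
#   memory = torch.randn(49, 2, 768)   # 7x7 feature map flattened to 49 tokens
#   assert layer(tgt, memory).shape == (100, 2, 768)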

# @torch.jit.script
# class ExtrapClasses(object):
#     def __init__(self, num_queries: int, group_size: int):
#         self.num_queries = num_queries
#         self.group_size = group_size
#
#     def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor,
#                  out_extrap: torch.Tensor):
#         # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
#         h = h[..., None].repeat(1, 1, 1, self.group_size)  # torch.Size([bs, 5, 768, groups])
#         w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
#         out = (h * w).sum(dim=2) + class_embed_b
#         out = out.view((h.shape[0], self.group_size * self.num_queries))
#         return out

@torch.jit.script
class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        # Each query embedding h[:, i, :] gets its own projection matrix
        # duplicate_pooling[i]; results are written into out_extrap in place.
        for i in range(self.embed_len_decoder):
            h_i = h[:, i, :]
            w_i = duplicate_pooling[i, :, :]
            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
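
# The loop above is equivalent to one batched contraction (an illustrative
# sketch, not part of the original file):
#
#   out = torch.einsum('bik,ikj->bij', h, duplicate_pooling)
#
# with h: [bs, embed_len_decoder, d_model] and
# duplicate_pooling: [embed_len_decoder, d_model, duplicate_factor].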

class MLDecoder(nn.Module):
    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
        super(MLDecoder, self).__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        if embed_len_decoder > num_classes:
            embed_len_decoder = num_classes

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
                                                      dim_feedforward=dim_feedforward,
                                                      dropout=decoder_dropout)
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)

        # non-learnable queries
        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
        self.query_embed.requires_grad_(False)

        # group fully-connected
        self.num_classes = num_classes
        # duplicate_factor ~ ceil(num_classes / embed_len_decoder): classes predicted per query
        self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
        self.duplicate_pooling = torch.nn.Parameter(
            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
        self.group_fc = GroupFC(embed_len_decoder)
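
    # Grouping arithmetic (illustrative note, not part of the original file): with
    # num_classes=1000 and the default 100 queries, duplicate_factor = ceil(1000/100) = 10,
    # so the group FC produces 100 * 10 = 1000 raw scores. When num_classes does not
    # divide evenly, the surplus scores are dropped by the [:, :self.num_classes]
    # slice in forward() below.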
    def forward(self, x):
        if len(x.shape) == 4:  # CNN feature map, e.g. [bs, 2048, 7, 7]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:  # ViT-style token sequence, e.g. [bs, 197, 768]
            embedding_spatial = x
        embedding_spatial_786 = self.embed_standart(embedding_spatial)
        embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)

        bs = embedding_spatial_786.shape[0]
        query_embed = self.query_embed.weight
        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no memory allocation with expand
        h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1))  # [embed_len_decoder, batch, 768]
        h = h.transpose(0, 1)

        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)
        h_out = out_extrap.flatten(1)[:, :self.num_classes]
        h_out += self.duplicate_pooling_bias
        logits = h_out
        return logits
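

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original file):
    # push a random ResNet-style feature map through the head and check the output
    # shape. num_classes=80 is an arbitrary example value.
    decoder = MLDecoder(num_classes=80, initial_num_features=2048)
    features = torch.randn(2, 2048, 7, 7)  # [bs, channels, h, w]
    logits = decoder(features)
    assert logits.shape == (2, 80)
    print('MLDecoder logits shape:', tuple(logits.shape))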