1""" DropBlock, DropPath
2
3PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4
5Papers:
6DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7
8Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9
10Code:
11DropBlock impl inspired by two Tensorflow impl that I liked:
12- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14
15Hacked together by / Copyright 2020 Ross Wightman
16"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .grid import ndgrid

def drop_block_2d(
        x,
        drop_prob: float = 0.1,
        block_size: int = 7,
        gamma_scale: float = 1.0,
        with_noise: bool = False,
        inplace: bool = False,
        batchwise: bool = False
):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental Gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter: the per-position seed probability, chosen so that
    # expanding each seed to a block_size^2 block drops roughly drop_prob of the activations
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))
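
    # A quick sanity check of the seed rate (illustrative numbers, added for clarity; not
    # from the paper): for H = W = 28, block_size = 7, drop_prob = 0.1, gamma_scale = 1.0,
    #   gamma = 0.1 * 784 / 49 / (22 * 22) ~= 0.0033
    # so ~1.6 of the 484 interior positions seed a block, and 1.6 * 49 / 784 ~= 10% of the
    # activations are zeroed, matching drop_prob.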

    # Forces the block to be inside the feature map.
    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    # Inside the valid region this reduces to (uniform_noise >= gamma), i.e. a position becomes
    # a block seed (0) with probability gamma; outside the valid region it is always >= 1.
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    # Max-pooling the negated mask acts as a min-pool: each zero seed grows into a
    # clipped_block_size x clipped_block_size block of zeros.
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        # rescale the kept activations so their expected magnitude matches the undropped input
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
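

# Example usage (a sketch; shapes and values here are illustrative, not part of this module):
#
#   x = torch.randn(2, 16, 28, 28)                    # NCHW feature map
#   y = drop_block_2d(x, drop_prob=0.1, block_size=7)
#   assert y.shape == x.shape                         # masking preserves shape
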

def drop_block_fast_2d(
        x: torch.Tensor,
        drop_prob: float = 0.1,
        block_size: int = 7,
        gamma_scale: float = 1.0,
        with_noise: bool = False,
        inplace: bool = False,
):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental Gaussian noise option. Simplified from the above without concern
    for a valid block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter (same derivation as in drop_block_2d)
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Note the inverted convention vs drop_block_2d: here 1 marks a seed, and max_pool
    # grows each seed into a block of ones covering the region to drop.
    block_mask = torch.empty_like(x).bernoulli_(gamma)
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.empty_like(x).normal_()
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask  # flip to a keep mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
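

# Example usage (a sketch; values are illustrative, not part of this module):
#
#   x = torch.randn(2, 16, 28, 28)
#   y = drop_block_fast_2d(x, drop_prob=0.1, block_size=7)
#
# Unlike drop_block_2d, seed blocks here may extend past the feature map edge and get
# clipped, which is cheaper but slightly skews the effective drop rate near borders.
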

class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(
            self,
            drop_prob: float = 0.1,
            block_size: int = 7,
            gamma_scale: float = 1.0,
            with_noise: bool = False,
            inplace: bool = False,
            batchwise: bool = False,
            fast: bool = True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
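

# Example usage (a sketch; the surrounding layers are assumptions, not part of this module):
#
#   stage = nn.Sequential(
#       nn.Conv2d(16, 32, 3, padding=1),
#       nn.BatchNorm2d(32),
#       nn.ReLU(inplace=True),
#       DropBlock2d(drop_prob=0.1, block_size=7),
#   )
#   stage.train()                  # DropBlock is a no-op in eval mode or with drop_prob=0
#   y = stage(torch.randn(2, 16, 28, 28))
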

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted to
    change the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        # scale survivors by 1 / keep_prob so the expected activation is unchanged
        random_tensor.div_(keep_prob)
    return x * random_tensor
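

# Example usage (a sketch; shapes are illustrative, not part of this module):
#
#   x = torch.randn(8, 197, 768)                    # e.g. ViT tokens, (B, N, C)
#   y = drop_path(x, drop_prob=0.1, training=True)
#   # each of the 8 samples is zeroed independently with probability 0.1; survivors
#   # are scaled by 1 / 0.9
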

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
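

# Example usage (a sketch of the usual residual-block pattern; the block below is an
# assumption for illustration, not part of this module):
#
#   class ResidualMLP(nn.Module):
#       def __init__(self, dim: int, drop_path_rate: float = 0.1):
#           super().__init__()
#           self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
#           self.drop_path = DropPath(drop_path_rate)
#
#       def forward(self, x):
#           # the shortcut is always kept; only the transform branch is dropped per sample
#           return x + self.drop_path(self.mlp(x))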