""" DropBlock, DropPath

PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.

Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)

Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)

Code:
DropBlock impl inspired by two TensorFlow impls that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .grid import ndgrid


def drop_block_2d(
        x: torch.Tensor,
        drop_prob: float = 0.1,
        block_size: int = 7,
        gamma_scale: float = 1.0,
        with_noise: bool = False,
        inplace: bool = False,
        batchwise: bool = False
):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter: per-position seed probability chosen so that, once each
    # seed is expanded to a block_size x block_size block, ~drop_prob of activations are dropped
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
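
# Illustrative usage sketch (not part of the original module; shapes are arbitrary):
# the exact variant restricts block seeds to positions where a full block fits inside
# the feature map, then renormalizes the survivors to preserve the activation mean.
#
#   x = torch.randn(2, 16, 28, 28)
#   y = drop_block_2d(x, drop_prob=0.1, block_size=7)
#   assert y.shape == x.shape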


def drop_block_fast_2d(
        x: torch.Tensor,
        drop_prob: float = 0.1,
        block_size: int = 7,
        gamma_scale: float = 1.0,
        with_noise: bool = False,
        inplace: bool = False,
):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplified from the above, with no concern
    for the valid block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    block_mask = torch.empty_like(x).bernoulli_(gamma)
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.empty_like(x).normal_()
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
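
# Illustrative comparison sketch (assumption, not from the original source): the fast
# variant seeds blocks anywhere via bernoulli_(gamma), so blocks may be clipped at the
# feature-map edges and the realized drop rate near borders can deviate from drop_prob.
#
#   x = torch.randn(2, 16, 28, 28)
#   y_fast = drop_block_fast_2d(x, drop_prob=0.1, block_size=7)
#   y_full = drop_block_2d(x, drop_prob=0.1, block_size=7)
#   # Both preserve shape; the two masks differ because each call draws fresh noise.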


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(
            self,
            drop_prob: float = 0.1,
            block_size: int = 7,
            gamma_scale: float = 1.0,
            with_noise: bool = False,
            inplace: bool = False,
            batchwise: bool = False,
            fast: bool = True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
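
# Usage sketch (illustrative; channel counts and placement are assumptions): DropBlock2d
# is a drop-in nn.Module, typically placed after conv/norm/act in later network stages,
# where feature maps are small enough for a 7x7 block to remove a contiguous region.
#
#   stage = nn.Sequential(
#       nn.Conv2d(64, 64, kernel_size=3, padding=1),
#       nn.BatchNorm2d(64),
#       nn.ReLU(inplace=True),
#       DropBlock2d(drop_prob=0.1, block_size=7),
#   )
#   stage.train()  # DropBlock is active only in training mode
#   out = stage(torch.randn(2, 64, 28, 28))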


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
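
# Illustrative sketch (not from the original source): drop_path zeroes whole samples and
# rescales the survivors by 1 / keep_prob, so the output matches the input in expectation.
#
#   x = torch.ones(8, 4)
#   y = drop_path(x, drop_prob=0.5, training=True)
#   # Each of the 8 rows is now either all 0.0 or all 2.0 (i.e. 1 / keep_prob).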


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
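
# Usage sketch (illustrative; the Block module below is an assumption, not part of this
# file): DropPath is typically applied to the residual branch, so a "dropped" sample
# simply falls back to its identity path.
#
#   class Block(nn.Module):
#       def __init__(self, dim: int, drop_path_rate: float = 0.1):
#           super().__init__()
#           self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
#           self.drop_path = DropPath(drop_path_rate)
#
#       def forward(self, x):
#           return x + self.drop_path(self.mlp(x))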