""" Normalization + Activation Layers

Provides Norm+Act fns for standard PyTorch norm layers such as
* BatchNorm
* GroupNorm
* LayerNorm

This allows swapping with alternative layers that are natively both norm + act such as
* EvoNorm (evo_norm.py)
* FilterResponseNorm (filter_response_norm.py)
* InplaceABN (inplace_abn.py)

Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import Union, List, Optional, Any

import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d

from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert


def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
    act_layer = get_act_layer(act_layer)  # string -> nn.Module
    act_kwargs = act_kwargs or {}
    if act_layer is not None and apply_act:
        if inplace:
            act_kwargs['inplace'] = inplace
        act = act_layer(**act_kwargs)
    else:
        act = nn.Identity()
    return act
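
# Example (illustrative, not executed): `_create_act` resolves a registry name or class via
# `get_act_layer` and falls back to `nn.Identity()` when `apply_act=False`, so every layer
# below can call `self.act(x)` unconditionally. The 'silu' name is assumed to be registered.
#
#   act = _create_act('silu', inplace=True)        # -> nn.SiLU(inplace=True)
#   noop = _create_act(nn.ReLU, apply_act=False)   # -> nn.Identity()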


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(
            self,
            num_features,
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
            device=None,
            dtype=None,
    ):
        try:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
                **factory_kwargs,
            )
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
            )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        x = F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
        x = self.drop(x)
        x = self.act(x)
        return x
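
# Example (illustrative, not executed): BatchNormAct2d replaces a separate nn.BatchNorm2d +
# activation pair. Because it inherits from nn.BatchNorm2d, its state_dict keys (weight, bias,
# running_mean, running_var, num_batches_tracked) match a plain BN layer, which is what keeps
# checkpoints trained with separate bn + act loadable, as the class docstring notes.
#
#   bn_act = BatchNormAct2d(64, act_layer=nn.ReLU, inplace=True)
#   y = bn_act(torch.randn(2, 64, 32, 32))                        # norm -> drop -> act
#   bn_act.load_state_dict(nn.BatchNorm2d(64).state_dict())       # keys line up 1:1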


class SyncBatchNormAct(nn.SyncBatchNorm):
    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
    # but ONLY when used in conjunction with the timm conversion function below.
    # Do not create this module directly or use the PyTorch conversion function.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
        if hasattr(self, "drop"):
            x = self.drop(x)
        if hasattr(self, "act"):
            x = self.act(x)
        return x


def convert_sync_batchnorm(module, process_group=None):
    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if isinstance(module, BatchNormAct2d):
            # convert timm norm + act layer
            module_output = SyncBatchNormAct(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group=process_group,
            )
            # set act and drop attr from the original module
            module_output.act = module.act
            module_output.drop = module.drop
        else:
            # convert standard BatchNorm layers
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
    del module
    return module_output
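
# Example (illustrative sketch): in a DistributedDataParallel setup, convert with this helper
# rather than torch.nn.SyncBatchNorm.convert_sync_batchnorm so BatchNormAct2d layers keep their
# `act`/`drop` members. `build_model()` and `local_rank` are assumed to exist in the caller.
#
#   model = build_model()
#   model = convert_sync_batchnorm(model)     # BatchNormAct2d -> SyncBatchNormAct,
#                                             # plain BatchNorm* -> torch.nn.SyncBatchNorm
#   model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])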


class FrozenBatchNormAct2d(torch.nn.Module):
    """
    BatchNormAct2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        apply_act=True,
        act_layer=nn.ReLU,
        act_kwargs=None,
        inplace=True,
        drop_layer=None,
    ):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        x = x * scale + bias
        x = self.act(self.drop(x))
        return x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"


def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct` layers
    of the provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
        res = FrozenBatchNormAct2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = freeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
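
# Example (illustrative sketch): freeze all batch norm layers before fine-tuning, e.g. to keep
# pretrained running statistics intact when training with small batches. `build_model()` is an
# assumed constructor.
#
#   model = build_model()
#   model = freeze_batch_norm_2d(model)   # BatchNorm2d/BatchNormAct2d -> Frozen* equivalents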


def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` and `FrozenBatchNormAct2d` layers of the provided module into
    `BatchNorm2d` or `BatchNormAct2d` respectively. If `module` is itself an instance of one of these
    frozen layers, it is converted and the new layer is returned. Otherwise, the module is walked
    recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, FrozenBatchNormAct2d):
        res = BatchNormAct2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, FrozenBatchNorm2d):
        res = torch.nn.BatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
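
# Example (illustrative sketch): the two helpers convert back and forth, so batch norm can be
# frozen for an early fine-tuning phase and re-enabled later (note that momentum and
# track_running_stats come back as the BatchNorm defaults after unfreezing).
#
#   model = freeze_batch_norm_2d(model)       # stats and affine params become fixed buffers
#   # ... fine-tune ...
#   model = unfreeze_batch_norm_2d(model)     # back to trainable BatchNorm(Act)2d layers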


def _num_groups(num_channels, num_groups, group_size):
    if group_size:
        assert num_channels % group_size == 0
        return num_channels // group_size
    return num_groups


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
            self,
            num_channels,
            num_groups=32,
            eps=1e-5,
            affine=True,
            group_size=None,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(GroupNormAct, self).__init__(
            _num_groups(num_channels, num_groups, group_size),
            num_channels,
            eps=eps,
            affine=affine,
        )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x
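
# Example (illustrative sketch): because num_channels is the first argument, GroupNormAct can be
# bound with functools.partial and passed anywhere a `norm_layer(num_features)` callable is
# expected, which is the point of the flipped argument order noted above. `group_size` instead
# derives the group count from a fixed number of channels per group.
#
#   from functools import partial
#   norm_layer = partial(GroupNormAct, num_groups=8, act_layer=nn.SiLU)
#   gn = norm_layer(64)                        # 8 groups over 64 channels
#   gn2 = GroupNormAct(64, group_size=16)      # 64 // 16 = 4 groups
#   y = gn(torch.randn(2, 64, 16, 16))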


class GroupNorm1Act(nn.GroupNorm):
    def __init__(
            self,
            num_channels,
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


class LayerNormAct(nn.LayerNorm):
    def __init__(
            self,
            normalization_shape: Union[int, List[int], torch.Size],
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


class LayerNormAct2d(nn.LayerNorm):
    def __init__(
            self,
            num_channels,
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            act_kwargs=None,
            inplace=True,
            drop_layer=None,
    ):
        super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        x = self.drop(x)
        x = self.act(x)
        return x
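
# Example (illustrative, not executed): LayerNormAct2d normalizes over the channel dim of an
# NCHW tensor by permuting to NHWC around layer_norm, so it can stand in as a norm_layer for
# convolutional blocks where plain nn.LayerNorm would expect the normalized dims last.
#
#   ln = LayerNormAct2d(64, act_layer=nn.GELU)
#   y = ln(torch.randn(2, 64, 14, 14))    # output shape stays (2, 64, 14, 14)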