pytorch-image-models

Форк
0
/
split_batchnorm.py 
75 строк · 3.4 Кб
1
""" Split BatchNorm
2

3
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6
namespace.
7

8
This allows easily removing the auxiliary BN layers after training to efficiently
9
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10
'Disentangled Learning via An Auxiliary BN'
11

12
Hacked together by / Copyright 2020 Ross Wightman
13
"""
14
import torch
15
import torch.nn as nn
16

17

18
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19

20
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
21
                 track_running_stats=True, num_splits=2):
22
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
23
        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
24
        self.num_splits = num_splits
25
        self.aux_bn = nn.ModuleList([
26
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
27

28
    def forward(self, input: torch.Tensor):
29
        if self.training:  # aux BN only relevant while training
30
            split_size = input.shape[0] // self.num_splits
31
            assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
32
            split_input = input.split(split_size)
33
            x = [super().forward(split_input[0])]
34
            for i, a in enumerate(self.aux_bn):
35
                x.append(a(split_input[i + 1]))
36
            return torch.cat(x, dim=0)
37
        else:
38
            return super().forward(input)
39

40

41
def convert_splitbn_model(module, num_splits=2):
42
    """
43
    Recursively traverse module and its children to replace all instances of
44
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
45
    Args:
46
        module (torch.nn.Module): input module
47
        num_splits: number of separate batchnorm layers to split input across
48
    Example::
49
        >>> # model is an instance of torch.nn.Module
50
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
51
    """
52
    mod = module
53
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
54
        return module
55
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
56
        mod = SplitBatchNorm2d(
57
            module.num_features, module.eps, module.momentum, module.affine,
58
            module.track_running_stats, num_splits=num_splits)
59
        mod.running_mean = module.running_mean
60
        mod.running_var = module.running_var
61
        mod.num_batches_tracked = module.num_batches_tracked
62
        if module.affine:
63
            mod.weight.data = module.weight.data.clone().detach()
64
            mod.bias.data = module.bias.data.clone().detach()
65
        for aux in mod.aux_bn:
66
            aux.running_mean = module.running_mean.clone()
67
            aux.running_var = module.running_var.clone()
68
            aux.num_batches_tracked = module.num_batches_tracked.clone()
69
            if module.affine:
70
                aux.weight.data = module.weight.data.clone().detach()
71
                aux.bias.data = module.bias.data.clone().detach()
72
    for name, child in module.named_children():
73
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
74
    del module
75
    return mod
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.