# pytorch-image-models
# (source listing: 75 lines · 3.4 KB)
""" Split BatchNorm

A PyTorch BatchNorm layer that splits the input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layer, whose weight/bias
keys are the same as the original BN. All other splits pass through BN sub-layers under the
'.aux_bn' namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
17
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes each of ``num_splits`` equal batch slices with its own BN.

    In training mode the input batch is split along dim 0 into ``num_splits`` equal chunks.
    The first chunk goes through this layer's own (parent) BN parameters and buffers, and
    chunk ``i + 1`` goes through ``self.aux_bn[i]``. The per-chunk outputs are concatenated
    back in the original order. In eval mode only the parent BN is applied, to the whole batch.

    Args:
        num_features: number of channels, passed to the parent and every auxiliary BN.
        eps, momentum, affine, track_running_stats: forwarded unchanged to the parent and
            to every auxiliary ``nn.BatchNorm2d``.
        num_splits: number of batch splits; must be at least 2 (one parent BN plus at least
            one auxiliary BN).

    Raises:
        ValueError: if ``num_splits`` is less than 2.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
        if num_splits <= 1:
            raise ValueError('Should have at least one aux BN layer (num_splits at least 2)')
        self.num_splits = num_splits
        # One auxiliary BN per split after the first; the first split uses the parent BN.
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
            for _ in range(num_splits - 1)])

    def forward(self, input: torch.Tensor):
        """Normalize ``input``, using a separate BN per batch split while training.

        Args:
            input: batch tensor; it is split along dim 0.

        Returns:
            Tensor with the same batch ordering as ``input``.

        Raises:
            ValueError: in training mode, if the batch size is not divisible by ``num_splits``.
        """
        if not self.training:
            # Aux BN is only relevant while training; eval uses the parent BN for everything.
            return super().forward(input)

        batch_size = input.shape[0]
        split_size = batch_size // self.num_splits
        # Without this check a remainder chunk would be silently dropped from the output.
        if batch_size != split_size * self.num_splits:
            raise ValueError("batch size must be evenly divisible by num_splits")
        split_input = input.split(split_size)

        outputs = [super().forward(split_input[0])]
        for chunk, aux in zip(split_input[1:], self.aux_bn):
            outputs.append(aux(chunk))
        return torch.cat(outputs, dim=0)

def _clone_or_none(tensor):
    """Return a clone of ``tensor``, or ``None`` when ``tensor`` is ``None``.

    BN running-stat buffers are ``None`` for layers built with ``track_running_stats=False``,
    so they cannot be cloned unconditionally.
    """
    return None if tensor is None else tensor.clone()


def convert_splitbn_model(module, num_splits=2):
    """Recursively replace BatchNorm layers in ``module`` with ``SplitBatchNorm2d``.

    Every instance of ``torch.nn.modules.batchnorm._BatchNorm`` (instance norm excluded) is
    replaced by a ``SplitBatchNorm2d`` with the same ``num_features``, ``eps``, ``momentum``,
    ``affine`` and ``track_running_stats``. The parent BN adopts the original running buffers
    and copies of the affine weight/bias. Every auxiliary BN starts from independent clones of
    those values. The original's ``requires_grad`` flags and train/eval mode are preserved.
    Non-BN modules are kept, and their children are converted in place.

    NOTE(review): every ``_BatchNorm`` subclass (e.g. BatchNorm1d/3d) is converted to a 2d
    layer; presumably that only accepts 4D input — confirm this is intended for your model.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across

    Returns:
        A new ``SplitBatchNorm2d`` if ``module`` itself is a BN layer, otherwise ``module``
        with its BN descendants replaced.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        # Instance norm is never converted. NOTE(review): presumably this guards PyTorch
        # versions in which _InstanceNorm derived from _BatchNorm — confirm.
        return module

    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # The parent BN adopts the original buffers as-is (shared tensors, not copies).
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
            # Freshly built Parameters default to requires_grad=True; keep frozen ones frozen.
            mod.weight.requires_grad_(module.weight.requires_grad)
            mod.bias.requires_grad_(module.bias.requires_grad)
        for aux in mod.aux_bn:
            # Each aux BN gets independent copies; buffers may be None (no running stats).
            aux.running_mean = _clone_or_none(module.running_mean)
            aux.running_var = _clone_or_none(module.running_var)
            aux.num_batches_tracked = _clone_or_none(module.num_batches_tracked)
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()
                aux.weight.requires_grad_(module.weight.requires_grad)
                aux.bias.requires_grad_(module.bias.requires_grad)
        # New modules start in training mode; mirror the original's train/eval state.
        # ``train`` recurses, so the aux BN layers follow as well.
        mod.train(module.training)

    # Convert children; for a non-BN container this rewires its submodules in place.
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod