pytorch-image-models
49 строк · 1.7 Кб
""" Median Pool
Hacked together by / Copyright 2020 Ross Wightman
"""
4import torch.nn as nn
5import torch.nn.functional as F
6from .helpers import to_2tuple, to_4tuple
7
8
class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Each output element is the median of a ``kernel_size`` window taken from the
    padded input. When a window holds an even number of elements, ``torch.median``
    returns the lower of the two middle values.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad;
            ignored when ``same`` is True
        same: override padding and enforce same padding, boolean
        padding_mode: ``F.pad`` mode used to fill the border (default 'reflect').
            'reflect' requires every pad amount to be smaller than the matching
            input dimension; 'replicate' has no such restriction.
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False, padding_mode='reflect'):
        super().__init__()
        self.k = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_4tuple(padding)  # convert to l, r, t, b
        self.same = same
        self.padding_mode = padding_mode

    @staticmethod
    def _same_pad(size, kernel, stride):
        """Total TF-style 'SAME' padding for one spatial dim of length ``size``."""
        remainder = size % stride
        # A divisible size needs (kernel - stride); otherwise the final partial
        # stride only needs (kernel - remainder). Never negative.
        return max(kernel - (stride if remainder == 0 else remainder), 0)

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``."""
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        ph = self._same_pad(ih, self.k[0], self.stride[0])
        pw = self._same_pad(iw, self.k[1], self.stride[1])
        # Split each total so any odd extra pixel goes to the right / bottom.
        pl = pw // 2
        pt = ph // 2
        return pl, pw - pl, pt, ph - pt

    def forward(self, x):
        padding = self._padding(x)
        if any(padding):
            # Skip the pad (and its copy) entirely when nothing needs padding.
            x = F.pad(x, padding, mode=self.padding_mode)
        # (N, C, H, W) -> (N, C, oH, oW, kH, kW): one kH x kW window per output position.
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # Merge the two window dims, then take the median over them; [0] keeps the
        # values and drops the indices returned by median(dim=...).
        x = x.flatten(-2).median(dim=-1)[0]
        return x
50