# pytorch-image-models
# 73 lines · 3.0 KB
""" AvgPool2d w/ Same Padding

Hacked together by / Copyright 2020 Ross Wightman
"""
5import torch6import torch.nn as nn7import torch.nn.functional as F8from typing import List, Tuple, Optional9
10from .helpers import to_2tuple11from .padding import pad_same, get_padding_value12
13
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
                    ceil_mode: bool = False, count_include_pad: bool = True):
    """Functional 2D average pooling with TensorFlow-like 'SAME' padding.

    ``x`` is first padded by ``pad_same`` (called with its default fill value), then pooled
    by ``F.avg_pool2d`` with no implicit padding of its own.

    Args:
        x: Input tensor.
        kernel_size: Pooling window size, one entry per spatial dim.
        stride: Pooling stride, one entry per spatial dim.
        padding: Ignored. Kept so the signature mirrors ``F.avg_pool2d``; the 'SAME'
            padding is derived from the input size by ``pad_same`` instead.
        ceil_mode: Forwarded to ``F.avg_pool2d``.
        count_include_pad: Forwarded to ``F.avg_pool2d``.

    Returns:
        The pooled tensor.
    """
    # FIXME how to deal with count_include_pad vs not for external padding?
    # count_include_pad only concerns padding that F.avg_pool2d adds itself (none here), so the
    # values inserted by pad_same are ordinary tensor elements and always enter the average.
    padded = pad_same(x, kernel_size, stride)
    return F.avg_pool2d(
        padded,
        kernel_size,
        stride=stride,
        padding=(0, 0),
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
    )
20
class AvgPool2dSame(nn.AvgPool2d):
    """ Tensorflow like 'SAME' wrapper for 2D average pooling

    The input is padded dynamically in ``forward`` (via ``pad_same``) and then pooled with
    ``F.avg_pool2d`` using zero implicit padding. The ``padding`` constructor argument is
    accepted for signature compatibility with ``nn.AvgPool2d`` but is ignored.
    """

    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
        """
        Args:
            kernel_size: Pooling window size (int or pair).
            stride: Pooling stride (int or pair). ``None`` means "same as kernel_size",
                matching ``nn.AvgPool2d``.
            padding: Ignored; 'SAME' padding is computed per input in ``forward``.
            ceil_mode: Forwarded to ``nn.AvgPool2d`` / ``F.avg_pool2d``.
            count_include_pad: Forwarded to ``nn.AvgPool2d`` / ``F.avg_pool2d``.
        """
        kernel_size = to_2tuple(kernel_size)
        # to_2tuple presumably broadcasts a scalar, so a None stride would become (None, None).
        # nn.AvgPool2d only replaces a bare None with kernel_size, so that tuple would be kept
        # and break pad_same / F.avg_pool2d in forward. Resolve the default explicitly.
        stride = kernel_size if stride is None else to_2tuple(stride)
        super(AvgPool2dSame, self).__init__(
            kernel_size, stride=stride, padding=(0, 0),
            ceil_mode=ceil_mode, count_include_pad=count_include_pad)

    def forward(self, x):
        """Pad ``x`` for a 'SAME' output size, then average-pool it."""
        x = pad_same(x, self.kernel_size, self.stride)
        return F.avg_pool2d(
            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
34
def max_pool2d_same(
        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
        dilation: List[int] = (1, 1), ceil_mode: bool = False):
    """Functional 2D max pooling with TensorFlow-like 'SAME' padding.

    Args:
        x: Input tensor.
        kernel_size: Pooling window size, one entry per spatial dim.
        stride: Pooling stride, one entry per spatial dim.
        padding: Ignored. Kept so the signature mirrors ``F.max_pool2d``; the 'SAME'
            padding is derived from the input size by ``pad_same`` instead.
        dilation: Window dilation, one entry per spatial dim.
        ceil_mode: Forwarded to ``F.max_pool2d``.

    Returns:
        The pooled tensor.
    """
    # A dilated window spans (k - 1) * d + 1 input positions, so dilation must be part of the
    # 'SAME' padding computation; previously it was dropped, giving too little padding (and a
    # too-small output) whenever dilation > 1.
    # NOTE(review): assumes pad_same's 4th positional parameter is the dilation, as in
    # timm's pad_same(x, kernel_size, stride, dilation, value) — confirm against .padding.
    # Pad with -inf so padded positions can never be selected as the maximum.
    x = pad_same(x, kernel_size, stride, dilation, value=-float('inf'))
    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
41
class MaxPool2dSame(nn.MaxPool2d):
    """ Tensorflow like 'SAME' wrapper for 2D max pooling

    The input is padded dynamically in ``forward`` (via ``pad_same``, filling with -inf) and
    then pooled with ``F.max_pool2d`` using zero implicit padding. The ``padding`` constructor
    argument is accepted for signature compatibility with ``nn.MaxPool2d`` but is ignored.
    """

    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
        """
        Args:
            kernel_size: Pooling window size (int or pair).
            stride: Pooling stride (int or pair). ``None`` means "same as kernel_size",
                matching ``nn.MaxPool2d``.
            padding: Ignored; 'SAME' padding is computed per input in ``forward``.
            dilation: Window dilation (int or pair).
            ceil_mode: Forwarded to ``nn.MaxPool2d`` / ``F.max_pool2d``.
        """
        kernel_size = to_2tuple(kernel_size)
        # to_2tuple presumably broadcasts a scalar, so a None stride would become (None, None),
        # which nn.MaxPool2d keeps as-is. Resolve the default to kernel_size explicitly.
        stride = kernel_size if stride is None else to_2tuple(stride)
        dilation = to_2tuple(dilation)
        # Pass ceil_mode by keyword: nn.MaxPool2d's 5th positional parameter is
        # return_indices, so passing it positionally set return_indices=ceil_mode and left
        # self.ceil_mode permanently False.
        super(MaxPool2dSame, self).__init__(
            kernel_size, stride=stride, padding=(0, 0), dilation=dilation, ceil_mode=ceil_mode)

    def forward(self, x):
        """Pad ``x`` for a 'SAME' output size (accounting for dilation), then max-pool it."""
        # NOTE(review): assumes pad_same's 4th positional parameter is the dilation, as in
        # timm's pad_same(x, kernel_size, stride, dilation, value) — confirm against .padding.
        x = pad_same(x, self.kernel_size, self.stride, self.dilation, value=-float('inf'))
        return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
55
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
    """Create a 2D average or max pooling layer.

    Args:
        pool_type: ``'avg'`` or ``'max'``.
        kernel_size: Pooling window size.
        stride: Pooling stride; falsy values (``None``, ``0``) fall back to ``kernel_size``.
        **kwargs: May contain ``padding`` (default ``''``), which is resolved through
            ``get_padding_value``. All remaining entries are forwarded both to
            ``get_padding_value`` and to the pooling module constructor.

    Returns:
        ``AvgPool2dSame`` / ``MaxPool2dSame`` when ``get_padding_value`` reports that padding
        must be computed dynamically, otherwise ``nn.AvgPool2d`` / ``nn.MaxPool2d`` built with
        the resolved static padding.

    Raises:
        ValueError: If ``pool_type`` is not ``'avg'`` or ``'max'``.
    """
    if pool_type not in ('avg', 'max'):
        # Explicit raise instead of ``assert False``: asserts are stripped under ``python -O``,
        # which made this function silently return None for an unknown pool_type.
        raise ValueError(f'Unsupported pool type {pool_type}')
    stride = stride or kernel_size
    padding = kwargs.pop('padding', '')
    padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
    if is_dynamic:
        pool_cls = AvgPool2dSame if pool_type == 'avg' else MaxPool2dSame
        return pool_cls(kernel_size, stride=stride, **kwargs)
    pool_cls = nn.AvgPool2d if pool_type == 'avg' else nn.MaxPool2d
    return pool_cls(kernel_size, stride=stride, padding=padding, **kwargs)