pytorch-image-models

Форк
0
73 строки · 3.0 Кб
1
""" AvgPool2d w/ Same Padding
2

3
Hacked together by / Copyright 2020 Ross Wightman
4
"""
5
import torch
6
import torch.nn as nn
7
import torch.nn.functional as F
8
from typing import List, Tuple, Optional
9

10
from .helpers import to_2tuple
11
from .padding import pad_same, get_padding_value
12

13

14
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
15
                    ceil_mode: bool = False, count_include_pad: bool = True):
16
    # FIXME how to deal with count_include_pad vs not for external padding?
17
    x = pad_same(x, kernel_size, stride)
18
    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
19

20

21
class AvgPool2dSame(nn.AvgPool2d):
22
    """ Tensorflow like 'SAME' wrapper for 2D average pooling
23
    """
24
    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25
        kernel_size = to_2tuple(kernel_size)
26
        stride = to_2tuple(stride)
27
        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
28

29
    def forward(self, x):
30
        x = pad_same(x, self.kernel_size, self.stride)
31
        return F.avg_pool2d(
32
            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
33

34

35
def max_pool2d_same(
36
        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
37
        dilation: List[int] = (1, 1), ceil_mode: bool = False):
38
    x = pad_same(x, kernel_size, stride, value=-float('inf'))
39
    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
40

41

42
class MaxPool2dSame(nn.MaxPool2d):
43
    """ Tensorflow like 'SAME' wrapper for 2D max pooling
44
    """
45
    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
46
        kernel_size = to_2tuple(kernel_size)
47
        stride = to_2tuple(stride)
48
        dilation = to_2tuple(dilation)
49
        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
50

51
    def forward(self, x):
52
        x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
53
        return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
54

55

56
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
57
    stride = stride or kernel_size
58
    padding = kwargs.pop('padding', '')
59
    padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
60
    if is_dynamic:
61
        if pool_type == 'avg':
62
            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
63
        elif pool_type == 'max':
64
            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
65
        else:
66
            assert False, f'Unsupported pool type {pool_type}'
67
    else:
68
        if pool_type == 'avg':
69
            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
70
        elif pool_type == 'max':
71
            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
72
        else:
73
            assert False, f'Unsupported pool type {pool_type}'
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.