pytorch-image-models

Форк
0
79 строк · 2.8 Кб
1
""" Padding Helpers
2

3
Hacked together by / Copyright 2020 Ross Wightman
4
"""
5
import math
6
from typing import List, Tuple
7

8
import torch
9
import torch.nn.functional as F
10

11

12
# Calculate symmetric padding for a convolution
13
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
14
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
15
    return padding
16

17

18
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
19
def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
20
    if isinstance(x, torch.Tensor):
21
        return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
22
    else:
23
        return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
24

25

26
# Can SAME padding for given args be done statically?
27
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
28
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
29

30

31
def pad_same_arg(
32
        input_size: List[int],
33
        kernel_size: List[int],
34
        stride: List[int],
35
        dilation: List[int] = (1, 1),
36
) -> List[int]:
37
    ih, iw = input_size
38
    kh, kw = kernel_size
39
    pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
40
    pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
41
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
42

43

44
# Dynamically pad input x with 'SAME' padding for conv with specified args
45
def pad_same(
46
        x,
47
        kernel_size: List[int],
48
        stride: List[int],
49
        dilation: List[int] = (1, 1),
50
        value: float = 0,
51
):
52
    ih, iw = x.size()[-2:]
53
    pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
54
    pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
55
    x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
56
    return x
57

58

59
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
60
    dynamic = False
61
    if isinstance(padding, str):
62
        # for any string padding, the padding will be calculated for you, one of three ways
63
        padding = padding.lower()
64
        if padding == 'same':
65
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
66
            if is_static_pad(kernel_size, **kwargs):
67
                # static case, no extra overhead
68
                padding = get_padding(kernel_size, **kwargs)
69
            else:
70
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
71
                padding = 0
72
                dynamic = True
73
        elif padding == 'valid':
74
            # 'VALID' padding, same as padding=0
75
            padding = 0
76
        else:
77
            # Default to PyTorch style 'same'-ish symmetric padding
78
            padding = get_padding(kernel_size, **kwargs)
79
    return padding, dynamic
80

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.