import torch
import torch.nn.functional as F

import numpy as np
from typing import List, Optional

from .expanded_weights_utils import \
    set_grad_sample_if_exists, unpack_expanded_weight_or_tensor

# Batch-size cutoff: below this (and with groups == 1) the grouped-convolution
# strategy for per-sample weight gradients is used; otherwise the unfold-based
# strategy is used. See weight_grad_sample in conv_backward.
THRESHOLD = 32


def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
    if func == F.conv1d:
        return conv1dOpt
    elif func == F.conv2d:
        return conv2dOpt
    else:
        assert func == F.conv3d
        return conv3dOpt
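
# Illustrative example (not part of the original module): conv_picker simply
# dispatches on the convolution's dimensionality, so
#   conv_picker(F.conv2d, 3, 4, 5) == 4
# It is used throughout this file to select per-dimension constants, such as
# the number of weight dims or the matching transposed-conv function.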


def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
    args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)]
    kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):]
    kwargs = dict(zip(kwarg_names, kwargs))

    return conv_normalizer(*args, **kwargs)


def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups}
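
# Illustrative example (hypothetical call, not from the original file):
#   conv_normalizer(x, w, stride=2)
#   -> ((x, w), {'bias': None, 'stride': 2, 'padding': 0, 'dilation': 1, 'groups': 1})
# i.e. positional and keyword conv arguments are normalized into a fixed
# (args, kwargs) layout that the expanded-weights code can rely on.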


def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
    if padding_style == "valid":
        return input
    else:
        padding = int_padding_for_string_padding(func, padding_style, dilation, kernel_size)
        return F.pad(input, padding)


def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
    def get_dilation(i):
        return dilation[i] if isinstance(dilation, tuple) else dilation

    if padding_style == "same":
        padding: List[int] = []
        # F.pad needs the padding in reverse order from what conv expects
        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
        return padding
    elif padding_style == "valid":
        return conv_picker(func, 2, 4, 6) * (0,)
    else:
        raise RuntimeError(f"got padding type of {padding_style}, only 'same' or 'valid' are accepted")
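
# Illustrative example (hypothetical values, not from the original file): for
# F.conv2d with kernel_size=(3, 5) and dilation=1,
#   int_padding_for_string_padding(F.conv2d, "same", 1, (3, 5)) == [2, 2, 1, 1]
# which is the per-side (W_left, W_right, H_top, H_bottom) layout F.pad
# expects, whereas conv's own padding argument is per-dimension (H, W).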


def conv_padding_for_same(dilation, kernel_size):
    total_pad = dilation * (kernel_size - 1)
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad
    return left_pad, right_pad
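
# Illustrative: the total "same" padding dilation * (kernel_size - 1) may be
# odd, in which case the extra element goes on the right-hand side, e.g.
#   conv_padding_for_same(1, 4) == (1, 2)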


def conv_backward(func, ctx, grad_output):

    def weight_grad_sample(weight):
        if batch_size < THRESHOLD and groups == 1:
            return conv_group_weight_grad_sample(ctx.input, grad_output, weight_shape, stride, padding, dilation, batch_size, func)
        else:
            return conv_unfold_weight_grad_sample(ctx.input, grad_output, weight_shape, kernel_size,
                                                  stride, padding, dilation, groups, func)

    def expand(param):
        if isinstance(param, int):
            return conv_picker(func, (param,), (param, param), (param, param, param))
        else:
            return param

    def calc_total_padding(func, was_same, padding, dilation, kernel_size):
        if was_same:
            all_padding = int_padding_for_string_padding(func, "same", dilation, kernel_size)
            # F.pad needs the padding in reverse order from what conv expects
            total_padding = tuple(all_padding[i] + all_padding[i - 1] for i in range(len(all_padding) - 1, -1, -2))
            return total_padding
        else:
            return tuple(2 * pad for pad in padding)

    weight_shape = ctx.weight.shape
    stride, padding, dilation, groups = expand(ctx.stride), expand(ctx.padding), expand(ctx.dilation), ctx.groups

    kernel_size = []
    for i in range(2, conv_picker(func, 3, 4, 5)):
        kernel_size.append(weight_shape[i])

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    # "same" padding may give uneven padding on either side, so we need to separate the "padding" attr and total padding
    total_padding = calc_total_padding(func, ctx.was_same_padding, padding, dilation, kernel_size)

    if ctx.input_required_grad:
        # The gradient w.r.t. the input is the transposed convolution of grad_output
        # with the weight; output_padding recovers input positions the forward
        # stride skipped over.
        output_padding = []
        input_dims = conv_picker(func, 1, 2, 3)
        for i in range(input_dims):
            input_dim = ctx.orig_input_shape[2 + i]
            output_padding.append((total_padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i])
        weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
        transpose_func = conv_picker(func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d)
        out = transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation)

        if ctx.was_same_padding:
            for i in range(len(total_padding)):
                out = torch.narrow(out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i])

        results.append(out)
    else:
        results.append(None)
    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 6

    # set grad_sample field for weight and bias with per sample gradients
    set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
    set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2))
    return tuple(results)
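
# Worked example of the output_padding computation above (hypothetical numbers,
# not from the original file): for F.conv1d with orig_input_shape[-1] = 6,
# kernel_size = 3, stride = 2, dilation = 1 and no padding, the forward output
# length is (6 - 3) // 2 + 1 = 2. The transposed conv alone would rebuild only
# (2 - 1) * 2 + 3 = 5 input positions, and output_padding
# = (0 + 6 - 3) % 2 = 1 restores the trailing position the stride skipped.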


def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func):
    n = input.shape[0]
    in_channels = input.shape[1]

    unfold_func = conv_picker(
        func,
        lambda: F.unfold(input.unsqueeze(-2),
                         kernel_size=(1, kernel_size[0]),
                         dilation=(1, dilation[0]),
                         padding=(0, padding[0]),
                         stride=(1, stride[0])),
        lambda: F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride),
        lambda: unfold3d(input, kernel_size, padding, stride, dilation)
    )

    input = unfold_func()
    grad_output = grad_output.reshape(n, -1, input.shape[-1])

    # n=batch_sz; o=num_out_channels; p=num_in_channels*kernel_sz; q=num_sliding_blocks
    weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
    # rearrange the above tensor and extract diagonals.
    weight_grad_sample = weight_grad_sample.view(
        n,
        groups,
        -1,
        groups,
        int(in_channels / groups),
        np.prod(kernel_size),
    )
    weight_grad_sample = torch.einsum("ngrg...->ngr...", weight_grad_sample).contiguous()
    shape = [n] + list(weight_shape)
    weight_grad_sample = weight_grad_sample.view(shape)
    return weight_grad_sample
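
# A minimal sanity-check sketch (assumption, not part of the original file):
# summing the per-sample gradients over the batch should reproduce the usual
# full-batch weight gradient for a sum loss.
#
#   x = torch.randn(4, 3, 8, 8)
#   w = torch.randn(5, 3, 3, 3, requires_grad=True)
#   F.conv2d(x, w).sum().backward()
#   per_sample = conv_unfold_weight_grad_sample(
#       x, torch.ones(4, 5, 6, 6), w.shape, (3, 3),
#       (1, 1), (0, 0), (1, 1), 1, F.conv2d)
#   # per_sample.shape == (4, 5, 3, 3, 3), and
#   # torch.testing.assert_close(per_sample.sum(0), w.grad) should pass.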


def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func):
    I = input.shape[1]
    O = grad_output.shape[1]

    # Fold the batch dimension into the channel/group dimension so that each
    # sample's gradient contribution becomes its own convolution group.
    input_ = input.transpose(0, 1)
    grad_output_ = grad_output.view(grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:])

    # Note the swapped roles: the forward stride acts as dilation here and vice versa.
    weight_grad_sample = func(input_, grad_output_, None, stride=dilation, padding=padding, dilation=stride, groups=batch_size)
    input_dims = conv_picker(func, 3, 4, 5)
    for i in range(2, input_dims):
        weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
    weight_grad_sample = weight_grad_sample.view(I, batch_size, O, *weight_grad_sample.shape[2:])
    weight_grad_sample = weight_grad_sample.movedim(0, 2)
    return weight_grad_sample
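
# Illustrative consistency check (assumption, not from the original file): for
# groups == 1 the two per-sample gradient strategies agree, e.g.
#   x, g = torch.randn(4, 3, 8, 8), torch.randn(4, 5, 6, 6)
#   w_shape = torch.Size([5, 3, 3, 3])
#   a = conv_group_weight_grad_sample(x, g, w_shape, (1, 1), (0, 0), (1, 1), 4, F.conv2d)
#   b = conv_unfold_weight_grad_sample(x, g, w_shape, (3, 3), (1, 1), (0, 0), (1, 1), 1, F.conv2d)
#   torch.testing.assert_close(a, b)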


def unfold3d(
    tensor,
    kernel_size,
    padding,
    stride,
    dilation,
):
    r"""
    Extract sliding local blocks from a batched input tensor.

    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This method implements the same action for 5D inputs.
    Args:
        tensor: An input tensor of shape ``(B, C, D, H, W)``.
        kernel_size: the size of the sliding blocks
        padding: implicit zero padding to be added on both sides of input
        stride: the stride of the sliding blocks in the input spatial dimensions
        dilation: the spacing between the kernel points.
    Returns:
        A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L is the total
        number of sliding blocks. See :class:`torch.nn.Unfold` for more details.
    Example:
        >>> # xdoctest: +SKIP
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
        >>> unfold3d(tensor, kernel_size=(2, 2, 2), padding=(0, 0, 0), stride=(1, 1, 1), dilation=(1, 1, 1)).shape
        torch.Size([3, 32, 120])
    """
    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got {tensor.shape}"
        )

    if dilation != (1, 1, 1):
        raise NotImplementedError(f"dilation={dilation} not supported.")

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W); F.pad consumes per-side pads for the last
    # dimension first, so padding[2] pads W, padding[1] pads H, padding[0] pads D
    tensor = F.pad(
        tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
    )
    # Output shape: (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])

    tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
        1, 2
    )
    # Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)

    return tensor
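
# Illustrative consistency check (assumption, not from the original file): with
# a singleton depth dimension, unfold3d reduces to the 2D F.unfold, e.g.
#   x = torch.randn(2, 3, 1, 8, 8)
#   a = unfold3d(x, (1, 3, 3), (0, 1, 1), (1, 1, 1), (1, 1, 1))
#   b = F.unfold(x.squeeze(2), (3, 3), padding=1)
#   torch.testing.assert_close(a, b)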