pytorch
240 строк · 9.5 Кб
1import torch
2import torch.nn.functional as F
3
4import numpy as np
5from typing import List, Optional
6
7from .expanded_weights_utils import \
8set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
9
# Batch-size cutoff used by conv_backward when choosing how to compute per-sample
# weight gradients: batches smaller than THRESHOLD with groups == 1 use a single
# grouped convolution (conv_group_weight_grad_sample); all other cases use the
# unfold-based path (conv_unfold_weight_grad_sample).
# NOTE(review): 32 is presumably an empirically chosen crossover point — confirm.
THRESHOLD = 32
11
12
def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
    """Return the option that matches the dimensionality of ``func``.

    Args:
        func: one of ``F.conv1d``, ``F.conv2d`` or ``F.conv3d``.
        conv1dOpt: value returned when ``func`` is ``F.conv1d``.
        conv2dOpt: value returned when ``func`` is ``F.conv2d``.
        conv3dOpt: value returned when ``func`` is ``F.conv3d``.

    Raises:
        AssertionError: if ``func`` is none of the three convolutions. It is
            raised explicitly rather than with ``assert`` so the check is not
            stripped under ``python -O`` (which would silently return
            ``conv3dOpt`` for any unsupported function).
    """
    if func == F.conv1d:
        return conv1dOpt
    if func == F.conv2d:
        return conv2dOpt
    if func == F.conv3d:
        return conv3dOpt
    raise AssertionError(f"conv_picker expected F.conv1d, F.conv2d or F.conv3d, got {func}")
21
22
def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
    """Split a flat argument sequence back into conv args/kwargs and normalize them.

    The trailing ``len(kwarg_names)`` entries of ``expanded_args_and_kwargs`` are
    the keyword values, paired with ``kwarg_names`` in order; everything before
    them is positional. The result is whatever ``conv_normalizer`` returns for
    that call.
    """
    split_at = len(expanded_args_and_kwargs) - len(kwarg_names)
    positional = expanded_args_and_kwargs[:split_at]
    keyword = dict(zip(kwarg_names, expanded_args_and_kwargs[split_at:]))
    return conv_normalizer(*positional, **keyword)
29
30
def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Normalize conv call arguments into ``((input, weight), kwargs)``.

    ``input`` and ``weight`` are returned positionally; every other convolution
    argument is returned by name (bias, stride, padding, dilation, groups) with
    the same defaults as the functional conv API.
    """
    named = dict(bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    return (input, weight), named
33
34
def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
    """Apply a string ``padding_style`` to ``input`` explicitly.

    ``"valid"`` means no padding, so ``input`` is returned untouched. Any other
    style is converted to integer pad amounts by ``int_padding_for_string_padding``
    (which raises for unsupported styles) and applied with ``F.pad``.
    """
    if padding_style != "valid":
        pad_amounts = int_padding_for_string_padding(func, padding_style, dilation, kernel_size)
        return F.pad(input, pad_amounts)
    return input
41
42
def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
    """Convert a string padding style into explicit ``F.pad`` amounts.

    Args:
        func: ``F.conv1d``, ``F.conv2d`` or ``F.conv3d``; selects the number of
            spatial dims.
        padding_style: ``"same"`` or ``"valid"``.
        dilation: an int used for every spatial dim, or a per-dim sequence
            (tuple or list).
        kernel_size: per-dim kernel sizes, indexed by spatial dim.

    Returns:
        For ``"same"``: a list holding a ``left, right`` pair per spatial dim,
        ordered last dim first as ``F.pad`` expects. For ``"valid"``: a tuple of
        zeros with two entries per spatial dim.

    Raises:
        RuntimeError: for any other ``padding_style``.
    """
    def get_dilation(i):
        # Index lists as well as tuples. Previously only tuples were indexed, so
        # a list dilation was passed whole to conv_padding_for_same, which then
        # failed doing arithmetic on a list.
        return dilation[i] if isinstance(dilation, (tuple, list)) else dilation

    if padding_style == "same":
        padding: List[int] = []
        # F.pad needs the padding in reverse order from what conv expects
        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
        return padding
    elif padding_style == "valid":
        return conv_picker(func, 2, 4, 6) * (0,)
    else:
        raise RuntimeError(f"got padding type of {padding_style}, only accept 'same' or 'valid'")
57
58
def conv_padding_for_same(dilation, kernel_size):
    """Return the ``(left, right)`` padding that "same" needs along one spatial dim.

    A kernel of ``kernel_size`` taps spaced ``dilation`` apart needs
    ``dilation * (kernel_size - 1)`` extra elements in total. They are split as
    evenly as possible, with any odd element going on the right.
    """
    half, remainder = divmod(dilation * (kernel_size - 1), 2)
    return half, half + remainder
64
65
def conv_backward(func, ctx, grad_output):
    """Backward pass for F.conv1d/2d/3d called with expanded (per-sample) weights.

    Returns the gradient w.r.t. the input, computed with the matching transposed
    convolution when ``ctx.input_required_grad`` is set. As a side effect, it
    stores per-sample gradients for the weight and bias through
    ``set_grad_sample_if_exists``.

    Args:
        func: ``F.conv1d``, ``F.conv2d`` or ``F.conv3d`` (checked via ``conv_picker``).
        ctx: context saved by the forward pass. Attributes read here: ``weight``,
            ``bias``, ``input``, ``stride``, ``padding``, ``dilation``, ``groups``,
            ``batch_size``, ``was_same_padding``, ``input_required_grad`` and
            ``orig_input_shape``.
        grad_output: gradient of the loss w.r.t. the convolution output; assumed
            to be (batch, out_channels, *spatial) — TODO confirm.

    Returns:
        A 9-tuple, in this order:
            - ``None`` for the kwarg names;
            - ``None`` for the op reference;
            - the input gradient, or ``None``;
            - six ``None`` entries for the remaining arguments.
        NOTE(review): presumably mirrors the forward's argument order — confirm.
    """

    def weight_grad_sample(weight):
        # Per-sample weight gradient. Small batches of ungrouped convs use one
        # grouped convolution; everything else uses the unfold + einsum path.
        # `weight` itself is unused; all inputs come from the enclosing scope.
        if (batch_size < THRESHOLD and groups == 1):
            return conv_group_weight_grad_sample(ctx.input, grad_output, weight_shape, stride, padding, dilation, batch_size, func)
        else:
            return conv_unfold_weight_grad_sample(ctx.input, grad_output, weight_shape, kernel_size,
                                                  stride, padding, dilation, groups, func)

    def expand(param):
        # Broadcast an int hyperparameter to one value per spatial dim;
        # non-int values are returned unchanged.
        if isinstance(param, int):
            return conv_picker(func, (param,), (param, param), (param, param, param))
        else:
            return param

    def calc_total_padding(func, was_same, padding, dilation, kernel_size):
        # Total padding (left + right) per spatial dim, in conv (first-dim-first) order.
        if was_same:
            all_padding = int_padding_for_string_padding(func, "same", dilation, kernel_size)
            # F.pad needs the padding in reverse order from what conv expects
            # all_padding holds [left, right] pairs ordered last dim first. Walking
            # it backwards two at a time sums each pair and restores conv order.
            total_padding = tuple(all_padding[i] + all_padding[i - 1] for i in range(len(all_padding) - 1, -1, -2))
            return total_padding
        else:
            # Integer padding is applied symmetrically on both sides.
            return tuple(2 * pad for pad in padding)

    weight_shape = ctx.weight.shape
    stride, padding, dilation, groups = expand(ctx.stride), expand(ctx.padding), expand(ctx.dilation), ctx.groups

    # Spatial kernel extent, taken from the weight's trailing (spatial) dims.
    kernel_size = []
    for i in range(2, conv_picker(func, 3, 4, 5)):
        kernel_size.append(weight_shape[i])

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding
    total_padding = calc_total_padding(func, ctx.was_same_padding, padding, dilation, kernel_size)

    if ctx.input_required_grad:
        # output_padding = (padded input length - dilated kernel extent) % stride,
        # i.e. the tail the strided forward conv never reached. The transposed conv
        # needs it to reproduce the original input length.
        output_padding = []
        input_dims = conv_picker(func, 1, 2, 3)
        for i in range(input_dims):
            input_dim = ctx.orig_input_shape[2 + i]
            output_padding.append((total_padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i])
        weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
        transpose_func = conv_picker(func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d)
        out = transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation)

        if ctx.was_same_padding:
            # Crop each spatial dim back to the original input size, starting after
            # the left padding (total // 2, matching conv_padding_for_same's split).
            for i in range(len(total_padding)):
                out = torch.narrow(out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i])

        results.append(out)
    else:
        results.append(None)
    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 6

    # set grad_sample field for weight and bias with per sample gradients
    set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
    # Per-sample bias gradient: grad_output summed over all spatial positions.
    set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2))
    return tuple(results)
129
130
def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func):
    """Per-sample weight gradients via input unfolding (im2col) and einsum.

    The input is unfolded into sliding patches: ``F.unfold`` for 1d/2d, with the
    1d signal treated as a height-1 image, and ``unfold3d`` for 3d. Each sample's
    ``grad_output`` is multiplied against that sample's own patches. The result
    is then regrouped so that, for grouped convolutions, only the block-diagonal
    part (same group on the output and input side) is kept.

    Returns:
        A tensor viewed as ``(input.shape[0], *weight_shape)``.
    """
    batch = input.shape[0]
    channels_in = input.shape[1]

    def _unfold_1d():
        # F.unfold only handles 4D input, so add a singleton height dim.
        return F.unfold(input.unsqueeze(-2),
                        kernel_size=(1, kernel_size[0]),
                        dilation=(1, dilation[0]),
                        padding=(0, padding[0]),
                        stride=(1, stride[0]))

    def _unfold_2d():
        return F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride)

    def _unfold_3d():
        return unfold3d(input, kernel_size, padding, stride, dilation)

    patches = conv_picker(func, _unfold_1d, _unfold_2d, _unfold_3d)()
    grads = grad_output.reshape(batch, -1, patches.shape[-1])

    # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
    per_sample = torch.einsum("noq,npq->nop", grads, patches)
    # Split out both group axes, then keep the diagonal where they coincide.
    per_sample = per_sample.view(
        batch,
        groups,
        -1,
        groups,
        int(channels_in / groups),
        np.prod(kernel_size),
    )
    per_sample = torch.einsum("ngrg...->ngr...", per_sample).contiguous()
    return per_sample.view([batch] + list(weight_shape))
164
165
def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func):
    """Per-sample weight gradients computed with a single grouped convolution.

    The batch dimension is turned into convolution groups:
        - the input is transposed to ``(in_channels, batch, *spatial)``;
        - ``grad_output`` is flattened to ``(batch * out_channels, 1, *spatial)``.
    A single ``func`` call with ``groups=batch_size`` therefore correlates each
    sample's input only with that sample's own ``grad_output``. Stride and
    dilation swap roles in this correlation.

    conv_backward only selects this path when ``groups == 1``.

    Returns:
        A tensor of shape ``(batch_size, out_channels, in_channels, *kernel)``.
    """
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]

    input_ = input.transpose(0, 1)
    # reshape (not view): grad_output may be non-contiguous (e.g. channels_last),
    # in which case merging the batch and channel dims cannot be expressed as a view.
    grad_output_ = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:])

    weight_grad_sample = func(input_, grad_output_, None, stride=dilation, padding=padding, dilation=stride, groups=batch_size)
    # The correlation may extend past the kernel along spatial dims; keep only the
    # first weight_shape[i] positions of every spatial dim of the conv output.
    for i in range(2, weight_grad_sample.dim()):
        weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
    weight_grad_sample = weight_grad_sample.view(in_channels, batch_size, out_channels, *weight_grad_sample.shape[2:])
    return weight_grad_sample.movedim(0, 2)
180
181
def unfold3d(
    tensor,
    kernel_size,
    padding,
    stride,
    dilation,
):
    r"""
    Extract sliding local blocks from a batched input tensor.

    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This method implements the same action for 5D inputs.

    Args:
        tensor: An input tensor of shape ``(B, C, D, H, W)``.
        kernel_size: the size of the sliding blocks; an int or a 3-element
            sequence ``(kD, kH, kW)``.
        padding: implicit zero padding added on both sides of each spatial dim;
            an int or a 3-element sequence ``(pD, pH, pW)``.
        stride: the stride of the sliding blocks in the input spatial dimensions;
            an int or a 3-element sequence ``(sD, sH, sW)``.
        dilation: the spacing between the kernel points; an int or a 3-element
            sequence. Only a dilation of 1 in every dimension is supported.

    Returns:
        A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where
        ``L = D_out * H_out * W_out`` is the number of output spatial positions.
        See :class:`torch.nn.Unfold` for more details.

    Raises:
        ValueError: if ``tensor`` is not 5-dimensional.
        NotImplementedError: if the dilation is not 1 in every dimension.

    Example:
        >>> # xdoctest: +SKIP
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
        >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1, dilation=1).shape
        torch.Size([3, 32, 120])
    """
    def _triple(value):
        # Normalize an int, or any 3-element sequence (tuple, list, torch.Size), to a tuple.
        if isinstance(value, int):
            return (value, value, value)
        return tuple(value)

    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got {tensor.shape}"
        )

    # Compare in normalized form so [1, 1, 1] and 1 are accepted like (1, 1, 1).
    if _triple(dilation) != (1, 1, 1):
        raise NotImplementedError(f"dilation={dilation} not supported.")

    kernel_size = _triple(kernel_size)
    padding = _triple(padding)
    stride = _triple(stride)

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W)
    # F.pad consumes pairs starting from the LAST dim: W gets padding[2],
    # H gets padding[1] and D gets padding[0].
    tensor = F.pad(
        tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
    )
    # Output shape: (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])

    tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    patch_size = channels * kernel_size[0] * kernel_size[1] * kernel_size[2]
    tensor = tensor.reshape(batch_size, -1, patch_size).transpose(1, 2)
    # Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)

    return tensor
241