import random

import torch
import torch.nn as nn

from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv


class MultidilatedConv(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
                 shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
        super().__init__()
        convs = []
        self.equal_dim = equal_dim
        assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode

        if comb_mode in ('cat_out', 'cat_both'):
            # Branch outputs are concatenated along the channel dimension and then
            # re-ordered by self.index so that channels from different dilations
            # are interleaved instead of stacked block by block.
            self.cat_out = True
            if equal_dim:
                # Every branch produces the same number of output channels.
                assert out_dim % dilation_num == 0
                out_dims = [out_dim // dilation_num] * dilation_num
                self.index = sum([[i + j * out_dims[0] for j in range(dilation_num)]
                                  for i in range(out_dims[0])], [])
            else:
                # Channel width is halved for each successive dilation; the last
                # branch takes whatever remains so the widths sum to out_dim.
                out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                out_dims.append(out_dim - sum(out_dims))
                index = []
                # Start offset of each branch's slab in the concatenated output
                # (cumulative sum of the preceding branch widths).
                starts = [0]
                for d in out_dims[:-1]:
                    starts.append(starts[-1] + d)
                lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
                for i in range(out_dims[-1]):
                    for j in range(dilation_num):
                        index += list(range(starts[j], starts[j] + lengths[j]))
                        starts[j] += lengths[j]
                self.index = index
                assert len(index) == out_dim
            self.out_dims = out_dims
        else:
            self.cat_out = False
            self.out_dims = [out_dim] * dilation_num

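        # Worked example of the 'cat_out' interleaving above (illustration only):
        # with equal_dim=True, out_dim=6 and dilation_num=3, each branch emits 2
        # channels, torch.cat stacks them as [b0c0, b0c1, b1c0, b1c1, b2c0, b2c1],
        # and self.index == [0, 2, 4, 1, 3, 5] re-orders that to
        # [b0c0, b1c0, b2c0, b0c1, b1c1, b2c1], mixing the dilations.
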
        if comb_mode in ('cat_in', 'cat_both'):
            # The input is split along the channel dimension and each chunk is
            # fed to a different dilated branch.
            if equal_dim:
                assert in_dim % dilation_num == 0
                in_dims = [in_dim // dilation_num] * dilation_num
            else:
                in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                in_dims.append(in_dim - sum(in_dims))
            self.in_dims = in_dims
            self.cat_in = True
        else:
            # Every branch sees the full input.
            self.cat_in = False
            self.in_dims = [in_dim] * dilation_num

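        # Worked example of the unequal split above (illustration only): with
        # equal_dim=False, dim=64 and dilation_num=3 the widths are halved per
        # branch and the remainder goes to the last one:
        # [64 // 2, 64 // 4, 64 - 48] == [32, 16, 16], summing back to 64.
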
        conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
        dilation = min_dilation
        for i in range(dilation_num):
            if isinstance(padding, int):
                # Scale the padding with the dilation so the spatial size is
                # preserved (exact for odd kernels with padding=(kernel_size-1)//2).
                cur_padding = padding * dilation
            else:
                cur_padding = padding[i]
            convs.append(conv_type(
                self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
            ))
            if i > 0 and shared_weights:
                # All branches reuse the first conv's parameters; only the
                # dilation (receptive field) differs between them.
                convs[-1].weight = convs[0].weight
                convs[-1].bias = convs[0].bias
            # Dilation doubles for each successive branch: min_dilation * 2**i.
            dilation *= 2
        self.convs = nn.ModuleList(convs)

        self.shuffle_in_channels = shuffle_in_channels
        if self.shuffle_in_channels:
            # shuffle list as shuffling of tensors is nondeterministic
            in_channels_permute = list(range(in_dim))
            random.shuffle(in_channels_permute)
            # save as buffer so it is saved and loaded with checkpoint
            self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))

    def forward(self, x):
        if self.shuffle_in_channels:
            x = x[:, self.in_channels_permute]

        outs = []
        if self.cat_in:
            # Split the input so each dilated branch gets its own channel chunk.
            if self.equal_dim:
                x = x.chunk(len(self.convs), dim=1)
            else:
                new_x = []
                start = 0
                for dim in self.in_dims:
                    new_x.append(x[:, start:start + dim])
                    start += dim
                x = new_x
        for i, conv in enumerate(self.convs):
            if self.cat_in:
                inp = x[i]
            else:
                inp = x
            outs.append(conv(inp))
        if self.cat_out:
            # Concatenate branch outputs and interleave their channels.
            out = torch.cat(outs, dim=1)[:, self.index]
        else:
            # Element-wise sum requires every branch to emit out_dim channels.
            out = sum(outs)
        return out
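

# Minimal usage sketch, not part of the original module: builds a
# MultidilatedConv for each combination mode and checks the output shape.
# The parameter values are arbitrary assumptions; running this requires the
# saicinpainting package to be importable (i.e. from inside the lama repo).
if __name__ == '__main__':
    x = torch.randn(2, 12, 32, 32)  # (batch, channels, height, width)
    for mode in ('sum', 'cat_in', 'cat_out', 'cat_both'):
        conv = MultidilatedConv(in_dim=12, out_dim=24, kernel_size=3,
                                dilation_num=3, comb_mode=mode)
        y = conv(x)
        # kernel_size=3 with the default padding=1 (scaled by each branch's
        # dilation) keeps the spatial size unchanged in every mode.
        assert y.shape == (2, 24, 32, 32), (mode, y.shape)
    print('all combination modes produce the expected output shape')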