lama

Форк
0
/
multidilated_conv.py 
98 строк · 4.0 Кб
1
import torch
2
import torch.nn as nn
3
import random
4
from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5

6
class MultidilatedConv(nn.Module):
    """Parallel convolutions with exponentially increasing dilation rates.

    Runs ``dilation_num`` convolutions (dilations ``min_dilation, 2*min_dilation,
    4*min_dilation, ...``) and combines their outputs according to ``comb_mode``:

    - ``'sum'``:     every conv maps ``in_dim -> out_dim``; outputs are summed.
    - ``'cat_out'``: every conv sees the full input; outputs are concatenated
      along channels and reshuffled (``self.index``) so channels from the
      different dilation branches are interleaved.
    - ``'cat_in'``:  the input channels are split between the branches; outputs
      are summed.
    - ``'cat_both'``: split input AND concatenate/interleave output.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: conv kernel size (passed through to the conv layers).
        dilation_num: number of parallel dilation branches.
        comb_mode: one of ``'sum'``, ``'cat_out'``, ``'cat_in'``, ``'cat_both'``.
        equal_dim: if True, channels are split evenly between branches;
            otherwise each extra branch gets half the previous one's budget
            and the last branch absorbs the remainder.
        shared_weights: if True, all branches share the first conv's
            weight/bias (NOTE(review): relies on ``.weight``/``.bias``
            attributes, so presumably only valid with ``nn.Conv2d``, not the
            depthwise-separable wrapper — confirm against callers).
        padding: int (scaled by each branch's dilation to keep spatial size
            for odd kernels) or a per-branch sequence of paddings.
        min_dilation: dilation of the first branch.
        shuffle_in_channels: if True, apply a fixed random permutation to the
            input channels (stored as a buffer so it persists in checkpoints).
        use_depthwise: use ``DepthWiseSeperableConv`` instead of ``nn.Conv2d``.
        **kwargs: forwarded to the conv constructor.
    """

    def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
                 shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
        super().__init__()
        convs = []
        self.equal_dim = equal_dim
        assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
        if comb_mode in ('cat_out', 'cat_both'):
            self.cat_out = True
            if equal_dim:
                assert out_dim % dilation_num == 0
                out_dims = [out_dim // dilation_num] * dilation_num
                # Interleave the concatenated outputs so the channel order is
                # branch0[0], branch1[0], ..., branchN[0], branch0[1], ...
                self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
            else:
                # Halve the channel budget for each extra branch; the last
                # branch takes whatever is left (keeps the total == out_dim).
                out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                out_dims.append(out_dim - sum(out_dims))
                index = []
                # BUGFIX: branch j's channels start at the cumulative sum of
                # the preceding branch widths. The previous
                # `starts = [0] + out_dims[:-1]` only coincides with the
                # cumulative sum for dilation_num <= 2; for dilation_num >= 3
                # it produced overlapping ranges, i.e. an index with duplicate
                # and missing channels (silently corrupting the output).
                starts = [0]
                for width in out_dims[:-1]:
                    starts.append(starts[-1] + width)
                # Each pass of the outer loop takes `lengths[j]` consecutive
                # channels from branch j, proportional to its width.
                lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
                for i in range(out_dims[-1]):
                    for j in range(dilation_num):
                        index += list(range(starts[j], starts[j] + lengths[j]))
                        starts[j] += lengths[j]
                self.index = index
                assert len(index) == out_dim
                # Sanity: the interleaving must be a permutation of channels.
                assert sorted(index) == list(range(out_dim))
            self.out_dims = out_dims
        else:
            self.cat_out = False
            self.out_dims = [out_dim] * dilation_num

        if comb_mode in ('cat_in', 'cat_both'):
            if equal_dim:
                assert in_dim % dilation_num == 0
                in_dims = [in_dim // dilation_num] * dilation_num
            else:
                in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                in_dims.append(in_dim - sum(in_dims))
            self.in_dims = in_dims
            self.cat_in = True
        else:
            self.cat_in = False
            self.in_dims = [in_dim] * dilation_num

        conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
        dilation = min_dilation
        for i in range(dilation_num):
            if isinstance(padding, int):
                # Scale padding with dilation so spatial size is preserved
                # for the usual kernel_size=3, padding=1 configuration.
                cur_padding = padding * dilation
            else:
                cur_padding = padding[i]
            convs.append(conv_type(
                self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
            ))
            if i > 0 and shared_weights:
                convs[-1].weight = convs[0].weight
                convs[-1].bias = convs[0].bias
            dilation *= 2
        self.convs = nn.ModuleList(convs)

        self.shuffle_in_channels = shuffle_in_channels
        if self.shuffle_in_channels:
            # Shuffle a Python list: shuffling tensors is nondeterministic.
            in_channels_permute = list(range(in_dim))
            random.shuffle(in_channels_permute)
            # Save as buffer so it is saved and loaded with the checkpoint.
            self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))

    def forward(self, x):
        """Apply all dilation branches to ``x`` (N, C, H, W) and combine.

        Returns a tensor with ``out_dim`` channels; spatial size is preserved
        when ``padding`` is an int and the kernel is odd.
        """
        if self.shuffle_in_channels:
            x = x[:, self.in_channels_permute]

        outs = []
        if self.cat_in:
            # Split the input channels between branches.
            if self.equal_dim:
                x = x.chunk(len(self.convs), dim=1)
            else:
                new_x = []
                start = 0
                for dim in self.in_dims:
                    new_x.append(x[:, start:start + dim])
                    start += dim
                x = new_x
        for i, conv in enumerate(self.convs):
            if self.cat_in:
                inp = x[i]  # renamed from `input` (shadowed the builtin)
            else:
                inp = x
            outs.append(conv(inp))
        if self.cat_out:
            # Concatenate then interleave channels across branches.
            out = torch.cat(outs, dim=1)[:, self.index]
        else:
            out = sum(outs)
        return out
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.