--- external/gsn/models/layers.py	2023-03-28 23:42:27.913662937 +0000
+++ external_reference/gsn/models/layers.py	2023-03-28 23:41:48.297127705 +0000
@@ -4,9 +4,31 @@
 from torch import nn
 import torch.nn.functional as F
 
-from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import bias_act
+from torch_utils import persistence
+
+# CHANGED: reimplement using torch_utils
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
+    x = bias_act.bias_act(input, bias, act='lrelu', alpha=negative_slope, gain=scale)
+    return x
+
+@persistence.persistent_class
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(channel))
+        else:
+            self.bias = None
 
+        self.negative_slope = negative_slope
+        self.scale = scale
 
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+@persistence.persistent_class
 class PixelNorm(nn.Module):
     """Pixel normalization layer.
 
@@ -20,6 +42,7 @@
         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
 
 
+@persistence.persistent_class
 class ConstantInput(nn.Module):
     """Constant input layer.
 
@@ -58,6 +81,7 @@
     return k
 
 
+@persistence.persistent_class
 class Blur(nn.Module):
     """Blur layer.
 
@@ -81,7 +105,7 @@
     def __init__(self, kernel, pad, upsample_factor=1):
         super().__init__()
 
-        kernel = make_kernel(kernel)
+        kernel = upfirdn2d.setup_filter(kernel)
 
         if upsample_factor > 1:
             kernel = kernel * (upsample_factor ** 2)
@@ -90,10 +114,11 @@
         self.pad = pad
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, pad=self.pad)
+        out = upfirdn2d.upfirdn2d(input, self.kernel, padding=self.pad)
         return out
 
 
+@persistence.persistent_class
 class Upsample(nn.Module):
     """Upsampling layer.
 
@@ -112,19 +137,21 @@
         super().__init__()
 
        self.factor = factor
-        kernel = make_kernel(kernel) * (factor ** 2)
+        kernel = upfirdn2d.setup_filter(kernel)  # * (factor ** 2) is handled by the upsample2d gain argument
+        # kernel = make_kernel(kernel) * (factor ** 2)
         self.register_buffer("kernel", kernel)
 
-        p = kernel.shape[0] - factor
-        pad0 = (p + 1) // 2 + factor - 1
-        pad1 = p // 2
-        self.pad = (pad0, pad1)
+        # p = kernel.shape[0] - factor
+        # pad0 = (p + 1) // 2 + factor - 1
+        # pad1 = p // 2
+        self.pad = (0, 0)  # upfirdn2d.upsample2d handles additional padding
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+        out = upfirdn2d.upsample2d(input, self.kernel, up=self.factor, padding=self.pad)
         return out
 
 
+@persistence.persistent_class
 class Downsample(nn.Module):
     """Downsampling layer.
 
@@ -143,19 +170,22 @@
         super().__init__()
 
         self.factor = factor
-        kernel = make_kernel(kernel)
+        kernel = upfirdn2d.setup_filter(kernel)  # make_kernel(kernel)
         self.register_buffer("kernel", kernel)
 
+        # downsample needs padding
         p = kernel.shape[0] - factor
         pad0 = (p + 1) // 2
         pad1 = p // 2
         self.pad = (pad0, pad1)
+        # self.pad = (0, 0)
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+        out = upfirdn2d.upfirdn2d(input, self.kernel, up=1, down=self.factor, padding=self.pad)
         return out
 
 
+@persistence.persistent_class
 class EqualLinear(nn.Module):
     """Linear layer with equalized learning rate.
 
@@ -207,6 +237,7 @@
         return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
 
 
+@persistence.persistent_class
 class EqualConv2d(nn.Module):
     """2D convolution layer with equalized learning rate.
 
@@ -256,6 +287,7 @@
     )
 
 
+@persistence.persistent_class
 class EqualConvTranspose2d(nn.Module):
     """2D transpose convolution layer with equalized learning rate.
 
@@ -315,6 +347,7 @@
     )
 
 
+@persistence.persistent_class
 class ConvLayer2d(nn.Sequential):
     def __init__(
         self,
@@ -367,6 +400,7 @@
         super().__init__(*layers)
 
 
+@persistence.persistent_class
 class ConvResBlock2d(nn.Module):
     """2D convolutional residual block with equalized learning rate.
 
@@ -417,6 +451,7 @@
         return out
 
 
+@persistence.persistent_class
 class ModulationLinear(nn.Module):
     """Linear modulation layer.
 
@@ -497,6 +532,7 @@
         return out
 
 
+@persistence.persistent_class
 class ModulatedConv2d(nn.Module):
     """2D convolutional modulation layer.
 
@@ -631,6 +667,7 @@
         return out
 
 
+@persistence.persistent_class
 class ToRGB(nn.Module):
     """Output aggregation layer.
 
@@ -675,6 +712,7 @@
         return out
 
 
+@persistence.persistent_class
 class ConvRenderBlock2d(nn.Module):
     """2D convolutional neural rendering block.
 
@@ -742,6 +780,7 @@
         return x, rgb
 
 
+@persistence.persistent_class
 class PositionalEncoding(nn.Module):
     """Positional encoding layer.
 
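
The reimplemented fused_leaky_relu maps the old fused StyleGAN2 op onto torch_utils.ops.bias_act: per-channel bias add, leaky ReLU with the given slope, then a sqrt(2) gain. A minimal equivalence sketch, assuming NVIDIA's torch_utils package (and its dnnlib dependency) from the StyleGAN2-ADA/StyleGAN3 codebase is importable; the reference line restates the standard StyleGAN2 definition and is not part of the patch:

import torch
import torch.nn.functional as F
from torch_utils.ops import bias_act

def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # Same wrapper as in the patch: bias + leaky ReLU + gain in one call.
    return bias_act.bias_act(input, bias, act='lrelu', alpha=negative_slope, gain=scale)

x = torch.randn(4, 8, 16, 16)
b = torch.randn(8)

# Plain-PyTorch reference: per-channel bias, leaky ReLU, then the sqrt(2) gain.
ref = F.leaky_relu(x + b.view(1, -1, 1, 1), negative_slope=0.2) * (2 ** 0.5)
print(torch.allclose(fused_leaky_relu(x, b), ref, atol=1e-6))  # True (on CPU, bias_act falls back to its reference kernel)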

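setup_filter replaces the deleted make_kernel helper and, for the kernels used here, does the same thing: outer product of the 1D taps followed by normalization to unit sum. A quick check, restating the usual rosinality-style make_kernel that GSN derives from (same torch_utils assumption as above):

import torch
from torch_utils.ops import upfirdn2d

def make_kernel(k):
    # The deleted helper, restated for comparison.
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]  # outer product -> 2D kernel
    k /= k.sum()
    return k

print(torch.allclose(make_kernel([1, 3, 3, 1]),
                     upfirdn2d.setup_filter([1, 3, 3, 1])))  # True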
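The Upsample change works because upfirdn2d.upsample2d computes the FIR padding internally and folds factor ** 2 into its gain, which is why the explicit pad math is commented out and self.pad drops to (0, 0). One subtlety: torch_utils parses a 2-element padding as (padx, pady), applied symmetrically per axis, so reproducing the deleted before/after pair takes a 4-element list when pad0 != pad1. A numerical check of the equivalence, under the same assumptions as above:

import torch
from torch_utils.ops import upfirdn2d

factor = 2
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # normalized 4x4 blur kernel
x = torch.randn(1, 3, 8, 8)

# Patched path: implicit padding, gain = factor ** 2 applied inside upsample2d.
out_new = upfirdn2d.upsample2d(x, f, up=factor, padding=(0, 0))

# Deleted GSN path: explicit padding and an explicit factor ** 2 gain.
p = f.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1  # = 2 for a 4-tap kernel, factor 2
pad1 = p // 2                     # = 1
out_old = upfirdn2d.upfirdn2d(x, f, up=factor,
                              padding=[pad0, pad1, pad0, pad1],  # before/after on both axes
                              gain=factor ** 2)

print(torch.allclose(out_new, out_old, atol=1e-6))  # True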
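Every layer also gains @persistence.persistent_class. In torch_utils, the decorator makes pickled instances embed the source of their defining module, so checkpoints can be un-pickled without the exact code revision that wrote them. A small round-trip sketch, assuming StyleGAN3-style persistence semantics and run as a script file on disk (the __main__ guard keeps the embedded source free of side effects when it is re-executed during unpickling):

import io
import pickle
import torch
from torch import nn
from torch_utils import persistence

@persistence.persistent_class
class Toy(nn.Module):
    def __init__(self, channel=4):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel))

if __name__ == "__main__":
    buf = io.BytesIO()
    pickle.dump(Toy(), buf)  # the pickle carries this module's source code
    buf.seek(0)
    toy = pickle.load(buf)   # rebuilt from the embedded source, plus saved state
    print(toy.bias.shape)    # torch.Size([4])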