google-research

Форк
0
61 строка · 3.1 Кб
1
--- external/eg3d/training/superresolution.py	2023-03-14 23:39:22.320046811 +0000
2
+++ external_reference/eg3d/training/superresolution.py	2023-03-14 23:28:05.379774778 +0000
3
@@ -12,14 +12,14 @@
4
 "Efficient Geometry-aware 3D Generative Adversarial Networks"."""
5
 
6
 import torch
7
-from training.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
8
+from external.stylegan.training.networks_stylegan2_terrain import Conv2dLayer, SynthesisLayer, ToRGBLayer
9
 from torch_utils.ops import upfirdn2d
10
 from torch_utils import persistence
11
 from torch_utils import misc
12
 
13
-from training.networks_stylegan2 import SynthesisBlock
14
+from external.stylegan.training.networks_stylegan2_terrain import SynthesisBlock
15
 import numpy as np
16
-from training.networks_stylegan3 import SynthesisLayer as AFSynthesisLayer
17
+from external.stylegan.training.networks_stylegan3_sky import SynthesisLayer as AFSynthesisLayer
18
 
19
 
20
 #----------------------------------------------------------------------------
21
@@ -57,25 +57,29 @@
22
 
23
 #----------------------------------------------------------------------------
24
 
25
-# for 256x256 generation
26
+# for 256x256 generation -- modified to support RGBD input and output in to_rgb
27
+# branch and to ignore style code input
28
 @persistence.persistent_class
29
 class SuperresolutionHybrid4X(torch.nn.Module):
30
     def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
31
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
32
-                **block_kwargs):
33
+                ignore_w=True, **block_kwargs):
34
         super().__init__()
35
         assert img_resolution == 256
36
         use_fp16 = sr_num_fp16_res > 0
37
         self.sr_antialias = sr_antialias
38
         self.input_resolution = 128
39
         self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
40
-                img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
41
+                img_channels=4, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
42
         self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
43
-                img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
44
+                img_channels=4, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
45
+        self.ignore_w = ignore_w
46
         self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
47
 
48
     def forward(self, rgb, x, ws, **block_kwargs):
49
         ws = ws[:, -1:, :].repeat(1, 3, 1)
50
+        if self.ignore_w:
51
+            ws = torch.zeros_like(ws) # input to affine layer with bias=1
52
 
53
         if x.shape[-1] < self.input_resolution:
54
             x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
55
@@ -289,4 +293,4 @@
56
         x, rgb = self.block1(x, rgb, ws, **block_kwargs)
57
         return rgb
58
 
59
-#----------------------------------------------------------------------------
60
\ No newline at end of file
61
+#----------------------------------------------------------------------------
62

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.