google-research

Форк
0
258 строк · 11.2 Кб
1
--- external/stylegan/training/dataset.py	2023-04-06 03:45:07.254378718 +0000
2
+++ external_reference/stylegan/training/dataset.py	2023-04-06 03:41:03.338611206 +0000
3
@@ -21,6 +21,9 @@
4
 except ImportError:
5
     pyspng = None
6
 
7
+import random
8
+from utils import midas, camera_util
9
+
10
 #----------------------------------------------------------------------------
11
 
12
 class Dataset(torch.utils.data.Dataset):
13
@@ -85,14 +88,60 @@
14
         return self._raw_idx.size
15
 
16
     def __getitem__(self, idx):
17
-        image = self._load_raw_image(self._raw_idx[idx])
18
+        image_info = self._load_raw_image(self._raw_idx[idx])
19
+        image = image_info['rgb']
20
+        depth = image_info['depth']
21
+        acc = image_info['acc']
22
+        K = image_info['K']
23
+        Rt = image_info['Rt']
24
+        original = image_info['orig']
25
+
26
+        # handle masked image flip
27
         assert isinstance(image, np.ndarray)
28
         assert list(image.shape) == self.image_shape
29
         assert image.dtype == np.uint8
30
         if self._xflip[idx]:
31
             assert image.ndim == 3 # CHW
32
             image = image[:, :, ::-1]
33
-        return image.copy(), self.get_label(idx)
34
+        image = image.copy()
35
+
36
+        # handle original image flip
37
+        assert isinstance(original, np.ndarray)
38
+        assert list(original.shape) == self.image_shape
39
+        assert original.dtype == np.uint8
40
+        if self._xflip[idx]:
41
+            assert original.ndim == 3 # CHW
42
+            original = original[:, :, ::-1]
43
+        original = original.copy()
44
+
45
+        # handle depth flip
46
+        assert isinstance(depth, np.ndarray)
47
+        assert list(depth.shape)[1:] == self.image_shape[1:] # depth has one channel
48
+        if self._xflip[idx]:
49
+            assert depth.ndim == 3 # CHW
50
+            depth = depth[:, :, ::-1]
51
+        depth = depth.copy()
52
+
53
+        # handle mask flip
54
+        assert isinstance(acc, np.ndarray)
55
+        assert list(acc.shape)[1:] == self.image_shape[1:] # acc has one channel
56
+        if self._xflip[idx]:
57
+            assert acc.ndim == 3 # CHW
58
+            acc = acc[:, :, ::-1]
59
+        acc = acc.copy()
60
+
61
+        # check intrinisics and extrinsics
62
+        assert isinstance(K, np.ndarray)
63
+        assert list(K.shape) == [3, 3]
64
+        assert isinstance(Rt, np.ndarray)
65
+        assert list(Rt.shape) == [4, 4]
66
+
67
+        # get flipped images
68
+        image_info['rgb'] = image
69
+        image_info['orig'] = original
70
+        image_info['depth'] = depth
71
+        image_info['acc'] = acc
72
+        return image_info, self.get_label(idx)
73
 
74
     def get_label(self, idx):
75
         label = self._get_raw_labels()[self._raw_idx[idx]]
76
@@ -157,14 +206,45 @@
77
     def __init__(self,
78
         path,                   # Path to directory or zip.
79
         resolution      = None, # Ensure specific resolution, None = highest available.
80
+        pose_path       = None, # Path to training pose distribution
81
+        depth_scale     = 16,   # scale factor for depth
82
+        depth_clip      = 20,   # clip all depths above this value
83
+        use_disp        = True, # use inverse depth if true
84
+        fov_mean        = 60,   # intrinsics mean FOV
85
+        fov_std         = 0,    # intrinsics std FOV
86
+        mask_downsample = 'antialias', # downsample mode for mask (antialias is softer boundary)
87
         **super_kwargs,         # Additional arguments for the Dataset base class.
88
     ):
89
         self._path = path
90
         self._zipfile = None
91
+        self._depth_path = path.replace('img', 'dpt_depth')
92
+        self._seg_path = path.replace('img', 'dpt_sky')
93
+        self.depth_scale = depth_scale
94
+        self.depth_clip = depth_clip
95
+        self.use_disp = use_disp
96
+        self.pose_path = pose_path
97
+        if self.pose_path is not None:
98
+            data = torch.load(self.pose_path)
99
+            self.Rt = data['Rts'].float().numpy()
100
+            self.cameras = data['cameras']
101
+        self.fov_mean = fov_mean
102
+        self.fov_std = fov_std
103
+        self.mask_downsample = mask_downsample
104
 
105
         if os.path.isdir(self._path):
106
             self._type = 'dir'
107
-            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
108
+            if os.path.isfile(self._path + '_cache.txt'):
109
+                with open(self._path + '_cache.txt') as cache:
110
+                    self._all_fnames = set([line.strip() for line in cache])
111
+            else:
112
+                print("Walking dataset...")
113
+                self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self._path)
114
+                                    for root, _dirs, files in
115
+                                    os.walk(self._path, followlinks=True) for fname in files]
116
+                with open(self._path + '_cache.txt', 'w') as cache:
117
+                    [cache.write("%s\n" % fname) for fname in self._all_fnames]
118
+                self._all_fnames = set(self._all_fnames)
119
+                print("done walking")
120
         elif self._file_ext(self._path) == '.zip':
121
             self._type = 'zip'
122
             self._all_fnames = set(self._get_zipfile().namelist())
123
@@ -177,9 +257,14 @@
124
             raise IOError('No image files found in the specified path')
125
 
126
         name = os.path.splitext(os.path.basename(self._path))[0]
127
-        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
128
-        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
129
-            raise IOError('Image files do not match the specified resolution')
130
+        if resolution is not None:
131
+            raw_shape = [len(self._image_fnames)] + [3, resolution, resolution] # list(self._load_raw_image(0).shape)
132
+        else:
133
+            # do not resize it to determine initial shape
134
+            raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0, resize=False)[0].shape)
135
+        # raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
136
+        # if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
137
+        #     raise IOError('Image files do not match the specified resolution')
138
         super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
139
 
140
     @staticmethod
141
@@ -209,17 +294,111 @@
142
     def __getstate__(self):
143
         return dict(super().__getstate__(), _zipfile=None)
144
 
145
-    def _load_raw_image(self, raw_idx):
146
+    def _load_raw_image(self, raw_idx, resize=True):
147
         fname = self._image_fnames[raw_idx]
148
+        ### load image
149
         with self._open_file(fname) as f:
150
-            if pyspng is not None and self._file_ext(fname) == '.png':
151
-                image = pyspng.load(f.read())
152
-            else:
153
-                image = np.array(PIL.Image.open(f))
154
+            image = PIL.Image.open(f).convert('RGB')
155
+        w, h = image.size
156
+        ### load depth map
157
+        depth_path = os.path.join(self._depth_path, fname.replace('png', 'pfm'))
158
+        disp, scale = midas.read_pfm(depth_path)
159
+        # normalize 0 to 1
160
+        disp = np.array(disp)
161
+        dmmin = np.percentile(disp, 1)
162
+        dmmax = np.percentile(disp, 99)
163
+        scaled_disp = (disp-dmmin) / (dmmax-dmmin + 1e-6)
164
+        scaled_disp = np.clip(scaled_disp, 0., 1.) * 255
165
+        disp_img = PIL.Image.fromarray(scaled_disp.astype(np.uint8))
166
+        # disparity mask needs to be done at full resolution
167
+        disp_mask_np = (scaled_disp/255 > 1/self.depth_clip).astype(np.uint8)
168
+        ### load sky mask
169
+        sky_path = os.path.join(self._seg_path, fname.replace('png', 'npz'))
170
+        sky_mask = np.load(sky_path)['sky_mask']
171
+        sky_img = PIL.Image.fromarray(sky_mask * 255)
172
+        ### remove sky from full size image (prevent sky color from leaking when downsampled)
173
+        image_np = np.array(image)
174
+        gray = np.array([128, 128, 128]).reshape(1, 1, -1)
175
+        image_gray = (image_np * sky_mask[..., None] + gray * (1-sky_mask[..., None])).astype(np.uint8)
176
+        image_gray = PIL.Image.fromarray(image_gray)
177
+
178
+        if resize:
179
+            # note: the input images should be square
180
+            assert(image.size[0] == image.size[1])
181
+            assert(self.image_shape[1] == self.image_shape[2])
182
+            target_size = self.image_shape[1:]
183
+            if image.size != target_size:
184
+                image = image.resize(target_size, PIL.Image.ANTIALIAS)
185
+                image_gray = image_gray.resize(target_size, PIL.Image.ANTIALIAS)
186
+                disp_img = disp_img.resize(target_size, PIL.Image.ANTIALIAS)
187
+                sky_img = sky_img.resize(target_size, PIL.Image.NEAREST if self.mask_downsample=='nearest' else PIL.Image.ANTIALIAS) 
188
+
189
+        # handle image dimensions
190
+        image = np.array(image)
191
         if image.ndim == 2:
192
             image = image[:, :, np.newaxis] # HW => HWC
193
         image = image.transpose(2, 0, 1) # HWC => CHW
194
-        return image
195
+        # handle image with sky mask dimensions
196
+        image_gray = np.array(image_gray)
197
+        if image_gray.ndim == 2:
198
+            image_gray = image_gray[:, :, np.newaxis] # HW => HWC
199
+        image_gray = image_gray.transpose(2, 0, 1) # HWC => CHW
200
+        # handle disp dimensions
201
+        disp = np.array(disp_img)
202
+        if disp.ndim == 2:
203
+            disp = disp[:, :, np.newaxis] # HW => HWC
204
+        disp = disp.transpose(2, 0, 1) # HWC => CHW
205
+        # handle sky mask dimensions
206
+        mask = np.array(sky_img)
207
+        if mask.ndim == 2:
208
+            mask = mask[:, :, np.newaxis] # HW => HWC
209
+        mask = mask.transpose(2, 0, 1) # HWC => CHW
210
+
211
+        # process mask
212
+        mask = mask / 255
213
+
214
+        # process disparity map (clip and rescale, to match nerf far)
215
+        disp = disp / 255 # convert back to [0, 1] range
216
+        disp_clipped = np.clip(disp, 1/self.depth_clip, 1) # range: [1/clip, 1]
217
+        psuedo_depth = 1/disp_clipped - 1 # range:[0, clip-1]
218
+        max_depth = self.depth_clip - 1
219
+        scaled_depth = psuedo_depth / max_depth * (self.depth_scale - 1) # range: [0, depth_scale-1]
220
+        scaled_disp = 1/(scaled_depth+1) # range: [1/depth_scale, 1]
221
+
222
+        # multiply everything by the downsampled mask
223
+        scaled_disp = scaled_disp * mask
224
+        scaled_depth = scaled_depth * mask
225
+        gray = np.array([128, 128, 128]).reshape(-1, 1, 1)
226
+        rgb_masked = (image_gray * mask + gray * (1-mask)).astype(np.uint8)
227
+
228
+        # intrinsics
229
+        K = np.zeros((3, 3))
230
+        fov = self.fov_mean + self.fov_std * np.random.randn()
231
+
232
+        fx = (self.image_shape[2] / 2) / np.tan(np.deg2rad(fov) / 2)
233
+        fy = (self.image_shape[1] / 2) / np.tan(np.deg2rad(fov) / 2)
234
+        K[0, 0] = fx
235
+        K[1, 1] = -fy
236
+        K[2, 2] = -1
237
+
238
+        # extrinsics
239
+        if self.pose_path is not None:
240
+            idx = random.randint(0, self.Rt.shape[0]-1)
241
+            Rt = self.Rt[idx].astype(np.float64)[0]
242
+            camera = self.cameras[idx]
243
+        else:
244
+            Rt = np.eye(4)
245
+            camera = camera_util.Camera(0., 0., 0., 0., 0.)
246
+
247
+        return {'rgb': rgb_masked,
248
+                'depth': scaled_disp if self.use_disp else scaled_depth,
249
+                'acc': mask,
250
+                'K': K, # 3x3
251
+                'Rt': Rt, #4x4
252
+                'global_size': self.image_shape[-1],
253
+                'fov': fov,
254
+                'orig': image,
255
+               }
256
 
257
     def _load_raw_labels(self):
258
         fname = 'dataset.json'
259

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.