google-research

--- external/eg3d/training/dataset.py	2023-04-06 03:55:39.232777639 +0000
+++ external_reference/eg3d/training/dataset.py	2023-04-06 03:41:04.542718663 +0000
@@ -23,6 +23,8 @@
 except ImportError:
     pyspng = None
 
+from utils import midas
+
 #----------------------------------------------------------------------------
 
 class Dataset(torch.utils.data.Dataset):
@@ -91,7 +93,7 @@
         image = self._load_raw_image(self._raw_idx[idx])
         assert isinstance(image, np.ndarray)
         assert list(image.shape) == self.image_shape
-        assert image.dtype == np.uint8
+        # assert image.dtype == np.uint8 # depth is float values
         if self._xflip[idx]:
             assert image.ndim == 3 # CHW
             image = image[:, :, ::-1]
@@ -163,14 +165,33 @@
     def __init__(self,
         path,                   # Path to directory or zip.
         resolution      = None, # Ensure specific resolution, None = highest available.
+        depth_scale     = 16,   # Scale factor for depth.
+        depth_clip      = 20,   # Clip all depths above this value.
+        white_sky       = False, # Mask sky with white pixels if true.
         **super_kwargs,         # Additional arguments for the Dataset base class.
     ):
         self._path = path
         self._zipfile = None
+        self.depth_clip = depth_clip
+        self.depth_scale = depth_scale
+        self.white_sky = white_sky
 
         if os.path.isdir(self._path):
             self._type = 'dir'
-            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
+            # Note: the cache file lives inside the dataset directory; it also
+            # lists disp/seg images, but those are filtered out in _image_fnames.
+            if os.path.isfile(self._path + '/cache.txt'):
+                with open(self._path + '/cache.txt') as cache:
+                    self._all_fnames = set(line.strip() for line in cache)
+            else:
+                print("Walking dataset...")
+                self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self._path)
+                                    for root, _dirs, files in
+                                    os.walk(self._path, followlinks=True) for fname in files]
+                with open(self._path + '/cache.txt', 'w') as cache:
+                    cache.writelines("%s\n" % fname for fname in self._all_fnames)
+                self._all_fnames = set(self._all_fnames)
+                print("done walking")
         elif self._file_ext(self._path) == '.zip':
             self._type = 'zip'
             self._all_fnames = set(self._get_zipfile().namelist())
@@ -216,16 +237,63 @@
         return dict(super().__getstate__(), _zipfile=None)
 
     def _load_raw_image(self, raw_idx):
+        ### Modified to return an RGBD image; the disparity and sky mask
+        # are flipped when "mirror" appears in the image path.
+
        fname = self._image_fnames[raw_idx]
+        ### load image
         with self._open_file(fname) as f:
-            if pyspng is not None and self._file_ext(fname) == '.png':
-                image = pyspng.load(f.read())
-            else:
-                image = np.array(PIL.Image.open(f))
+            image = np.array(PIL.Image.open(f))
+        h, w, _ = image.shape # HWC
+        ### load depth map
+        depth_path = (os.path.join(self._path, fname)
+                      .replace('png', 'pfm')
+                      .replace('img', 'disp')
+                      .replace('_mirror', ''))
+        disp, scale = midas.read_pfm(depth_path)
+        # normalize to [0, 1] using robust percentiles
+        disp = np.array(disp)
+        dmmin = np.percentile(disp, 1)
+        dmmax = np.percentile(disp, 99)
+        scaled_disp = (disp-dmmin) / (dmmax-dmmin + 1e-6)
+        scaled_disp = np.clip(scaled_disp, 0., 1.) * 255
+        disp_img = PIL.Image.fromarray(scaled_disp.astype(np.uint8))
+        if 'mirror' in fname:
+            scaled_disp = np.fliplr(scaled_disp)
+            disp_img = disp_img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
+
+        ### load sky mask
+        sky_path = (os.path.join(self._path, fname)
+                    .replace('png', 'npz')
+                    .replace('img', 'seg')
+                    .replace('_mirror', ''))
+        sky_mask = np.load(sky_path)['sky_mask']
+        sky_img = PIL.Image.fromarray(sky_mask * 255)
+        if 'mirror' in fname:
+            sky_mask = np.fliplr(sky_mask)
+            sky_img = sky_img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
+
+        # process image
         if image.ndim == 2:
             image = image[:, :, np.newaxis] # HW => HWC
         image = image.transpose(2, 0, 1) # HWC => CHW
-        return image
+
+        # process disparity map
+        disp = scaled_disp / 255 # convert back to [0, 1] range
+        disp_clipped = np.clip(disp, 1/self.depth_clip, 1) # range: [1/clip, 1]
+        pseudo_depth = 1/disp_clipped - 1 # range: [0, clip-1]
+        max_depth = self.depth_clip - 1
+        scaled_depth = pseudo_depth / max_depth * (self.depth_scale - 1) # range: [0, depth_scale-1]
+        scaled_disp = 1/(scaled_depth+1) # range: [1/depth_scale, 1]
+
+        # zero out disparity in the sky region (sky_mask is 0 in sky)
+        scaled_disp = scaled_disp * sky_mask
+
+        if self.white_sky:
+            sky_color = np.array([255, 255, 255]).reshape(-1, 1, 1)
+            image = (image * sky_mask[None] + sky_color * (1-sky_mask[None]))
+
+        return np.concatenate([image, scaled_disp[None]], axis=0)
 
     def _load_raw_labels(self):
         fname = 'dataset.json'
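
A note on the directory layout implied by the .replace() chains in _load_raw_image: RGB frames live under img/ as .png, MiDaS disparity maps under disp/ as .pfm, and sky segmentations under seg/ as .npz; a mirrored frame reuses the files of its unmirrored source, and the arrays are flipped in memory instead of being stored twice. A worked example with hypothetical filenames:

    fname = 'img/0001_mirror.png'
    fname.replace('png', 'pfm').replace('img', 'disp').replace('_mirror', '')  # 'disp/0001.pfm'
    fname.replace('png', 'npz').replace('img', 'seg').replace('_mirror', '')   # 'seg/0001.npz'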

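For intuition, the disparity processing in _load_raw_image can be read as a single standalone transform. The sketch below is a minimal re-derivation rather than code from the patch: it skips the uint8 round-trip through PIL, uses the hypothetical helper name rescale_disparity, and takes depth_clip=20 and depth_scale=16 from the defaults above.

    import numpy as np

    def rescale_disparity(disp, depth_clip=20, depth_scale=16):
        # Robust normalization to [0, 1] via the 1st/99th percentiles.
        dmin, dmax = np.percentile(disp, 1), np.percentile(disp, 99)
        disp = np.clip((disp - dmin) / (dmax - dmin + 1e-6), 0., 1.)
        # Bound disparity away from zero so inverse depth stays finite.
        disp = np.clip(disp, 1 / depth_clip, 1)   # range: [1/clip, 1]
        pseudo_depth = 1 / disp - 1               # range: [0, clip-1]
        scaled_depth = pseudo_depth / (depth_clip - 1) * (depth_scale - 1)
        return 1 / (scaled_depth + 1)             # range: [1/depth_scale, 1]

    disp = np.random.rand(64, 64).astype(np.float32)  # stand-in for a MiDaS .pfm
    out = rescale_disparity(disp)
    print(out.min(), out.max())  # stays within [1/16, 1] up to float rounding

In effect the far plane moves from depth_clip to depth_scale in inverse-depth space; after multiplication by the sky mask, sky pixels carry disparity 0, and the sample returned to the training loop is a 4-channel CHW array (RGB plus rescaled disparity).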