google-research
124 строки · 5.5 Кб
1--- external/eg3d/training/dataset.py 2023-04-06 03:55:39.232777639 +0000
2+++ external_reference/eg3d/training/dataset.py 2023-04-06 03:41:04.542718663 +0000
3@@ -23,6 +23,8 @@
4except ImportError:
5pyspng = None
6
7+from utils import midas
8+
9#----------------------------------------------------------------------------
10
11class Dataset(torch.utils.data.Dataset):
12@@ -91,7 +93,7 @@
13image = self._load_raw_image(self._raw_idx[idx])
14assert isinstance(image, np.ndarray)
15assert list(image.shape) == self.image_shape
16- assert image.dtype == np.uint8
17+ # dtype check disabled: the appended depth channel holds float values, not uint8
18if self._xflip[idx]:
19assert image.ndim == 3 # CHW
20image = image[:, :, ::-1]
21@@ -163,14 +165,33 @@
22def __init__(self,
23path, # Path to directory or zip.
24resolution = None, # Ensure specific resolution, None = highest available.
25+ depth_scale = 16, # scale factor for depth
26+ depth_clip = 20, # clip all depths above this value
27+ white_sky = False, # mask sky with white pixels if true
28**super_kwargs, # Additional arguments for the Dataset base class.
29):
30self._path = path
31self._zipfile = None
32+ self.depth_clip = depth_clip
33+ self.depth_scale = depth_scale
34+ self.white_sky = white_sky
35
36if os.path.isdir(self._path):
37self._type = 'dir'
38- self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
39+ # NOTE: the file listing is cached in <path>/cache.txt inside the dataset directory;
40+ # the cache also lists the disp/seg files, but those are filtered out in _image_fnames
41+ if os.path.isfile(self._path + '/cache.txt'):
42+ with open(self._path + '/cache.txt') as cache:
43+ self._all_fnames = set([line.strip() for line in cache])
44+ else:
45+ print("Walking dataset...")
46+ self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self._path)
47+ for root, _dirs, files in
48+ os.walk(self._path, followlinks=True) for fname in files]
49+ with open(self._path + '/cache.txt', 'w') as cache:
50+ [cache.write("%s\n" % fname) for fname in self._all_fnames]
51+ self._all_fnames = set(self._all_fnames)
52+ print("done walking")
53elif self._file_ext(self._path) == '.zip':
54self._type = 'zip'
55self._all_fnames = set(self._get_zipfile().namelist())
56@@ -216,16 +237,63 @@
57return dict(super().__getstate__(), _zipfile=None)
58
59def _load_raw_image(self, raw_idx):
60+ ### Modified to return an RGBD image (image channels plus one disparity channel);
61+ # the disparity map and sky mask are flipped left-right if "mirror" is in the image path
62+
63fname = self._image_fnames[raw_idx]
64+ ### load image
65with self._open_file(fname) as f:
66- if pyspng is not None and self._file_ext(fname) == '.png':
67- image = pyspng.load(f.read())
68- else:
69- image = np.array(PIL.Image.open(f))
70+ image = np.array(PIL.Image.open(f))
71+ w, h, _ = image.shape
72+ ### load depth map
73+ depth_path = (os.path.join(self._path, fname)
74+ .replace('png', 'pfm')
75+ .replace('img', 'disp')
76+ .replace('_mirror', ''))
77+ disp, scale = midas.read_pfm(depth_path)
78+ # normalize to [0, 1] between the 1st and 99th percentiles, then scale to [0, 255]
79+ disp = np.array(disp)
80+ dmmin = np.percentile(disp, 1)
81+ dmmax = np.percentile(disp, 99)
82+ scaled_disp = (disp-dmmin) / (dmmax-dmmin + 1e-6)
83+ scaled_disp = np.clip(scaled_disp, 0., 1.) * 255
84+ disp_img = PIL.Image.fromarray(scaled_disp.astype(np.uint8))
85+ if 'mirror' in fname:
86+ scaled_disp = np.fliplr(scaled_disp)
87+ disp_img = disp_img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
88+
89+ ### load sky mask
90+ sky_path = (os.path.join(self._path, fname)
91+ .replace('png', 'npz')
92+ .replace('img', 'seg')
93+ .replace('_mirror', ''))
94+ sky_mask = np.load(sky_path)['sky_mask']
95+ sky_img = PIL.Image.fromarray(sky_mask * 255)
96+ if 'mirror' in fname:
97+ sky_mask = np.fliplr(sky_mask)
98+ sky_img = sky_img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
99+
100+ # process image
101if image.ndim == 2:
102image = image[:, :, np.newaxis] # HW => HWC
103image = image.transpose(2, 0, 1) # HWC => CHW
104- return image
105+
106+ # process disparity map
107+ disp = scaled_disp / 255 # convert back to [0, 1] range
108+ disp_clipped = np.clip(disp, 1/self.depth_clip, 1) # range: [1/clip, 1]
109+ psuedo_depth = 1/disp_clipped - 1 # range:[0, clip-1]
110+ max_depth = self.depth_clip - 1
111+ scaled_depth = psuedo_depth / max_depth * (self.depth_scale - 1) # range: [0, depth_scale-1]
112+ scaled_disp = 1/(scaled_depth+1) # range: [1/depth_scale, 1]
113+
114+ # mask the disparity: pixels where sky_mask is 0 get zero disparity
115+ scaled_disp = scaled_disp * sky_mask
116+
117+ if self.white_sky:
118+ sky_color = np.array([255, 255, 255]).reshape(-1, 1, 1)
119+ image = (image * sky_mask[None] + sky_color * (1-sky_mask[None]))
120+
121+ return np.concatenate([image, scaled_disp[None]], axis=0)
122
123def _load_raw_labels(self):
124fname = 'dataset.json'
125