google-research

Форк
0
113 строк · 4.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Removes LHQ images unsuitable for flying."""
17
import os
18

19
import cv2 as cv
20
import numpy as np
21
from PIL import Image
22
from tqdm import tqdm
23
from utils import midas
24

25
# Root of the raw LHQ data. It is expected to contain:
#   1) lhq_256:   the LHQ images at size 256x256
#   2) dpt_depth: the DPT depth result on lhq_256 (default setting)
#   3) dpt_seg:   the DPT segmentation result on lhq_256 (default setting)
# The path is made absolute because it is later used as a symlink target.
lhq = os.path.abspath('dataset/LHQ')
output_dir = 'dataset/lhq_processed'

# Output layout for the filtered dataset.
for _subdir in ('img', 'dpt_depth', 'dpt_depth-vis', 'dpt_sky'):
  os.makedirs(os.path.join(output_dir, _subdir), exist_ok=True)
# The cleaned sky masks are stored alongside the source data.
os.makedirs(os.path.join(lhq, 'dpt_sky_output'), exist_ok=True)
37

38
def _symlink_into(src, dst_dir):
  """Symlinks `src` into `dst_dir`, like `ln -s src dst_dir`.

  Args:
    src: path the link points to. It is stored verbatim in the link, so it
      should be absolute.
    dst_dir: existing directory in which the link is created; the link gets
      the base name of `src`.
  """
  link_path = os.path.join(dst_dir, os.path.basename(src))
  try:
    os.symlink(src, link_path)
  except FileExistsError:
    # Already linked by an earlier run; keep the existing entry so the script
    # can be re-run over the same output directory.
    pass


counter = 0  # number of images that passed every filter
total = 90000  # number of images to scan; files are named by 7-digit index

for i in tqdm(range(total)):
  depth_path = os.path.join(lhq, 'dpt_depth', '%07d.pfm' % i)
  # read_pfm returns a (data, scale) pair; only the data is needed here
  disp, _ = midas.read_pfm(depth_path)

  # normalize the disparity by its maximum, so that the edge threshold used
  # further down is relative to the image's own disparity range
  disp_norm = disp / np.max(disp)

  seg = np.array(Image.open(os.path.join(lhq, 'dpt_seg', '%07d_seg.png' % i)))

  ### remove small contours
  # sky_mask is 0 in the sky region (segmentation label 3) and 1 elsewhere
  sky_mask = 1 - (seg == 3).astype(np.uint8)
  contours, _ = cv.findContours(
      sky_mask * 255, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE
  )
  # processing mask is zero for small contours; one otherwise
  processing_mask = np.ones_like(sky_mask)
  for j, cnt in enumerate(contours):
    if cv.contourArea(cnt) < 250:
      cv.drawContours(processing_mask, contours, j, (0, 0, 0), cv.FILLED)
  # zero out small contours, and save the intermediate result (this is done
  # for every image, whether or not it is kept)
  sky_mask_processed = processing_mask * sky_mask
  sky_path = os.path.join(lhq, 'dpt_sky_output', '%07d.npz' % i)
  np.savez(sky_path, sky_mask=sky_mask_processed)
  contours_processed, _ = cv.findContours(
      sky_mask_processed * 255, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE
  )

  keep_img = True
  # too many contours --> trees
  if len(contours_processed) > 3:
    keep_img = False
  # check sky region --> too much non-sky region
  if np.mean(sky_mask_processed) > 0.9:
    keep_img = False
  # check upper part of scene (top fifth of the rows)
  h, w = sky_mask_processed.shape
  upper = sky_mask_processed[: h // 5, :]
  if np.mean(upper) > 0.4:
    # too much foreground in upper part of image
    keep_img = False
  # check lower part of scene (bottom quarter of the rows)
  lower = sky_mask_processed[-h // 4 :, :]
  if np.mean(lower) < 0.8:
    # too much sky in lower part of image
    keep_img = False
  # check not too many vertical edges: 99th percentile of the absolute
  # disparity difference between horizontally adjacent pixels
  horizontal_diff = np.abs(disp_norm[:, 1:] - disp_norm[:, :-1])
  if np.percentile(horizontal_diff, 99) > 0.05:
    keep_img = False

  if keep_img:
    # link the depth map, the image and the cleaned sky mask into the output
    # tree; the sources live under `lhq`, which is an absolute path
    _symlink_into(depth_path, os.path.join(output_dir, 'dpt_depth'))
    _symlink_into(
        os.path.join(lhq, 'lhq_256', '%07d.png' % i),
        os.path.join(output_dir, 'img'),
    )
    _symlink_into(sky_path, os.path.join(output_dir, 'dpt_sky'))
    counter += 1
    if i < 5000:
      # write an 8-bit min-max normalized preview, only for indices < 5000
      disp_min = np.min(disp)
      disp_range = np.max(disp) - disp_min
      if disp_range > 0:
        disp_out = (disp - disp_min) / disp_range
      else:
        # constant disparity map: avoid 0/0 and emit an all-black preview
        disp_out = np.zeros_like(disp)
      im = Image.fromarray((disp_out * 255).astype(np.uint8))
      im.save(os.path.join(output_dir, 'dpt_depth-vis', '%07d.png' % i))

print('kept %d images %0.2f' % (counter, counter / total))
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.