# google-research (106 lines, 3.0 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate training cameras.

Randomly sample a camera origin, then compute the set of valid rotations
whose sampled frustum points stay inside the scene layout.
"""
import os

from external.gsn.models.nerf_utils import get_sample_points
import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import camera_util
28
# Half-extent of the scene layout in each horizontal axis is plane_width / 2;
# 256 grid cells at 0.15 units each (= 38.4 units total).
plane_width = 256 * 0.15  # 38.4
# Far plane used for the saved pose filename and (halved) for frustum checks.
nerf_far = 16
# Camera-height jitter: Ty ~ N(0, (ymax / 3)^2), so ymax is a ~3-sigma bound.
ymax = 0.5
# Fixed seed so the sampled pose distribution is reproducible.
seed = 0
rng = np.random.RandomState(seed)
34
# Sample 1000 camera poses whose view frustums overlap the layout.
sampled_Rts = []
sampled_cameras = []
for i in range(1000):
  if i % 100 == 0:
    print(i)

  # Camera position: uniform in the ground plane, Gaussian jitter in height.
  Tx = rng.rand() * plane_width - plane_width / 2
  Tz = rng.rand() * plane_width - plane_width / 2
  Ty = rng.randn() * ymax / 3

  valid_degrees = []
  # find the rotations that are valid
  for degree in range(360):
    # compute world2cam matrix
    camera = camera_util.Camera(Tx, Ty, Tz, degree, 0.0)
    Rt = camera_util.pose_from_camera(camera)[None]
    # convert to cam2world matrix
    # (used FOV=90, to reproduce previous pose distribution)
    xyz, viewdirs, zvals, rd, ro = get_sample_points(
        tform_cam2world=Rt.inverse(),
        F=(16, -16),
        H=1,
        W=32,
        samples_per_ray=2,
        near=0,
        far=nerf_far / 2,
        perturb=False,
        mask=None,
    )
    # Rotation is valid only if every sampled frustum point lies inside
    # the layout's horizontal half-extent.
    if np.all(np.abs(xyz).numpy() < plane_width / 2):
      valid_degrees.append(degree)
  # sample from the valid degrees
  # NOTE(review): rng.choice raises if valid_degrees is empty — presumably
  # every position admits at least one valid rotation here; confirm.
  degree = rng.choice(valid_degrees) + (rng.rand() - 0.5)
  camera = camera_util.Camera(Tx, Ty, Tz, degree, 0.0)
  Rt = camera_util.pose_from_camera(camera)[None]

  # store world2cam transformation
  sampled_Rts.append(Rt)
  sampled_cameras.append(camera)
74
# Sanity-check visualization: top-down (x/z) scatter of 500 random camera
# origins, each with an arrow along its central viewing ray.
f, ax = plt.subplots()
for i in np.random.choice(len(sampled_Rts), 500):
  Rt = sampled_Rts[i]
  # Single central ray (W=1), densely sampled to draw the view direction.
  xyz, viewdirs, zvals, rd, ro = get_sample_points(
      tform_cam2world=Rt.inverse(),
      F=(16, -16),
      H=1,
      W=1,
      samples_per_ray=64,
      near=0,
      far=8,
      perturb=False,
      mask=None,
  )
  # First sample (near=0) is the camera origin itself.
  ax.scatter(xyz[0, 0, 0, 0], xyz[0, 0, 0, 2])
  # Arrow from the origin toward the 20th sample along the ray.
  ax.arrow(
      xyz[0, 0, 0, 0],
      xyz[0, 0, 0, 2],
      xyz[0, 0, 20, 0] - xyz[0, 0, 0, 0],
      xyz[0, 0, 20, 2] - xyz[0, 0, 0, 2],
  )
ax.set_xlim([-plane_width / 2, plane_width / 2])
ax.set_ylim([-plane_width / 2, plane_width / 2])
ax.set_aspect('equal', adjustable='box')
f.savefig('preprocessing/poses.jpg')
100
# save with noisy camera heights
os.makedirs('poses', exist_ok=True)
# Stack the per-pose (1, 4, 4) world2cam matrices into one (N, 1, 4, 4)
# tensor and save alongside the raw Camera objects.
torch.save(
    {'Rts': torch.stack(sampled_Rts), 'cameras': sampled_cameras},
    f'./poses/width{plane_width}_far{nerf_far}_noisy_height.pth',
)
107