google-research
409 lines · 10.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: disable=invalid-name,g-importing-member,g-multiple-import
17"""Utilities for camera manipulation."""
18from collections import namedtuple19from math import cos, sin, pi20import random21
22import numpy as np23import torch24
25
26####### camera utils
27
# Tuple to represent user camera position.
#   x, y, z: position in world space.
#   theta: horizontal direction to look, in degrees (0 = positive x).
#   psi: up/down angle, in degrees (0 = level).
Camera = namedtuple('Camera', ['x', 'y', 'z', 'theta', 'psi'])
def initial_camera():
  """Return a camera at the world origin, looking level along theta=0."""
  return Camera(x=0.0, y=0.0, z=0.0, theta=0.0, psi=0.0)
# Camera movement constants
ROTATION_HORIZONTAL_DEGREES = 5  # heading change per left/right key press
ROTATION_UPDOWN_DEGREES = 5  # pitch change per up/down key press
UPDOWN_MIN = -90  # pitch clamp bounds in degrees (see update_camera)
UPDOWN_MAX = 90
FORWARD_SPEED = 1 / 2  # world units per forward/backward key press
SIDEWAYS_SPEED = FORWARD_SPEED / 2  # strafing is half the forward speed
VERTICAL_SPEED = FORWARD_SPEED / 2  # vertical motion is half the forward speed
# Reset target used by update_camera's 'x' key; stays None unless set
# elsewhere, in which case reset falls back to initial_camera().
INITIAL_CAMERA = None
def pose_from_camera(camera):
  """A 4x4 pose matrix mapping world to camera space.

  Args:
    camera: Camera namedtuple (x, y, z, theta, psi).

  Returns:
    world2cam matrix as a 4x4 torch.Tensor, composed as Rx @ Ry @ T.
  """
  # Heading is offset by 90 degrees so that theta == 0 looks along +x.
  yaw = (camera.theta + 90) * pi / 180
  pitch = camera.psi * pi / 180
  cy, sy = cos(yaw), sin(yaw)
  cp, sp = cos(pitch), sin(pitch)

  # Rotation about the y axis (heading).
  rot_y = torch.tensor([
      [cy, 0, sy, 0],
      [0, 1, 0, 0],
      [-sy, 0, cy, 0],
      [0, 0, 0, 1],
  ])
  # Rotation about the x axis (up/down tilt).
  rot_x = torch.tensor([
      [1, 0, 0, 0],
      [0, cp, sp, 0],
      [0, -sp, cp, 0],
      [0, 0, 0, 1],
  ])
  # Translation bringing the camera position to the origin.
  trans = torch.tensor([
      [1, 0, 0, -camera.x],
      [0, 1, 0, -camera.y],
      [0, 0, 1, -camera.z],
      [0, 0, 0, 1],
  ])
  return rot_x.mm(rot_y).mm(trans)
def camera_from_pose(Rt):
  """Solve for camera variables from world2cam pose.

  Inverse of pose_from_camera.

  Args:
    Rt: 4x4 torch.Tensor, world2cam pose

  Returns:
    camera object
  """
  assert list(Rt.shape) == [4, 4]

  # Heading: row 0 of the rotation holds (cos_theta, ., sin_theta).
  theta = torch.atan2(Rt[0, 2], Rt[0, 0])
  theta = theta * 180 / pi
  theta = (theta - 90) % 360  # undo the +90 offset applied when posing

  # Tilt: read cos/sin of psi from column 1. Plain atan (not atan2) is
  # used — presumably psi stays in (-90, 90) as update_camera's clamp
  # enforces; confirm if poses come from elsewhere.
  psi = torch.atan(-Rt[2, 1] / Rt[1, 1])
  psi = psi * 180 / pi

  # With the rotation (Rx @ Ry) known, isolate the translation part.
  rot = pose_from_camera(Camera(0.0, 0.0, 0.0, theta.item(), psi.item()))
  T = torch.mm(rot.inverse(), Rt.cpu())
  return Camera(
      -T[0, 3].item(),
      -T[1, 3].item(),
      -T[2, 3].item(),
      theta.item(),
      psi.item(),
  )
def get_full_image_parameters(
    layout_model,
    nerf_render_size,
    batch_size,
    device='cuda',
    Rt=None,
    sample_fov=False,
):
  """Construct intrinsics for an image of size nerf_render_size.

  Args:
    layout_model: object exposing scalar fov_mean and fov_std attributes.
    nerf_render_size: output image size in pixels (square).
    batch_size: number of cameras in the batch.
    device: device for the returned K matrix.
    Rt: optional Bx4x4 (or Bx1x4x4) world2cam pose batch to pass through.
    sample_fov: if True, draw per-example FOVs; otherwise use the mean.

  Returns:
    dict with 'K' (Bx3x3 intrinsics), 'global_size', 'fov', and
    optionally 'Rt' (Bx4x4).
  """
  # Per-example FOV draw. With sample_fov=False the std is zeroed, so
  # every entry equals the mean; the randn call is kept in both paths to
  # preserve the global NumPy RNG stream exactly.
  std = layout_model.fov_std if sample_fov else 0.0
  fov = layout_model.fov_mean + std * np.random.randn(batch_size)

  sampled_size = np.array([nerf_render_size] * batch_size)
  # Pinhole relation: focal = (size / 2) / tan(fov / 2).
  focal = (sampled_size / 2) / np.tan(np.deg2rad(fov) / 2)

  # Intrinsics; negative y/z entries follow the renderer's convention.
  K = np.zeros((batch_size, 3, 3))
  K[:, 0, 0] = focal
  K[:, 1, 1] = -focal
  K[:, 2, 2] = -1  # Bx3x3

  camera_params = {
      'K': torch.from_numpy(K).float().to(device),
      'global_size': torch.from_numpy(sampled_size).float(),
      'fov': torch.from_numpy(fov).float(),
  }

  if Rt is not None:
    if Rt.ndim == 4:
      # Collapse a singleton camera dimension: Bx1x4x4 -> Bx4x4.
      assert Rt.shape[1] == 1
      Rt = Rt[:, 0, :, :]
    camera_params['Rt'] = Rt  # Bx4x4
  return camera_params
165# --------------------------------------------------------------------
166# camera motion utils
167
168
def update_camera(camera, key, auto_adjust_height_and_tilt=True):
  """Move the camera one step in response to a key press.

  Args:
    camera: current Camera namedtuple.
    key: single-character control ('wasdqe' move/turn, 'rf' tilt,
      'tg' raise/lower, 'x' reset).
    auto_adjust_height_and_tilt: when True, ignore the manual height and
      tilt keys ('r', 'f', 't', 'g') — those are driven automatically.

  Returns:
    A new Camera; the input is never mutated.
  """
  if key == 'x':
    # Reset: restore the saved initial camera if one exists.
    return INITIAL_CAMERA if INITIAL_CAMERA is not None else initial_camera()

  if auto_adjust_height_and_tilt and key in ('r', 'f', 't', 'g'):
    return camera  # these controls are handled automatically

  x, y, z = camera.x, camera.y, camera.z
  theta, psi = camera.theta, camera.psi
  cos_theta = cos(theta * pi / 180)
  sin_theta = sin(theta * pi / 180)

  # Rotation left and right.
  if key == 'a':
    theta -= ROTATION_HORIZONTAL_DEGREES
  elif key == 'd':
    theta += ROTATION_HORIZONTAL_DEGREES
  theta = theta % 360

  # Looking up and down, clamped to the legal pitch range.
  if key == 'r':
    psi += ROTATION_UPDOWN_DEGREES
  elif key == 'f':
    psi -= ROTATION_UPDOWN_DEGREES
  psi = max(UPDOWN_MIN, min(UPDOWN_MAX, psi))

  # Movement in 3 dimensions.
  if key == 'w':  # forward along the heading
    x += cos_theta * FORWARD_SPEED
    z += sin_theta * FORWARD_SPEED
  elif key == 's':  # backward
    x -= cos_theta * FORWARD_SPEED
    z -= sin_theta * FORWARD_SPEED
  elif key == 'q':  # strafe left
    x += sin_theta * SIDEWAYS_SPEED
    z -= cos_theta * SIDEWAYS_SPEED
  elif key == 'e':  # strafe right
    x -= sin_theta * SIDEWAYS_SPEED
    z += cos_theta * SIDEWAYS_SPEED
  elif key == 't':  # straight up
    y += VERTICAL_SPEED
  elif key == 'g':  # straight down
    y -= VERTICAL_SPEED
  return Camera(x, y, z, theta, psi)
def move_camera(camera, forward_speed, rotation_speed):
  """Turn the camera by rotation_speed degrees, then advance it.

  Args:
    camera: current Camera namedtuple.
    forward_speed: distance to travel in the x-z plane along the new heading.
    rotation_speed: degrees added to theta before moving.

  Returns:
    A new Camera at the updated position and heading (y and psi unchanged).
  """
  heading = camera.theta + rotation_speed
  rad = heading * pi / 180
  new_x = camera.x + cos(rad) * forward_speed
  new_z = camera.z + sin(rad) * forward_speed
  return Camera(new_x, camera.y, new_z, heading, camera.psi)
243# --------------------------------------------------------------------
244# camera balancing utils
245
# How far up the image should the horizon be, ideally.
# Suggested range: 0.5 to 0.7.
horizon_target = 0.65

# What proportion of the depth map should be "near" the camera, ideally.
# The smaller the number, the higher up the camera will fly.
# Suggested range: 0.05 to 0.2
near_target = 0.2

# Proportional-control gains used by update_tilt_and_offset when steering
# the camera toward the two targets above.
tilt_velocity_scale = 0.3
offset_velocity_scale = 0.5
258
def land_fraction(sky_mask):
  """Average value of the mask, returned as a Python float.

  NOTE(review): despite the name, this is the mean of sky_mask itself —
  presumably the mask is 1 over land pixels; verify against the caller.
  """
  return sky_mask.mean().item()
def near_fraction(depth, near_depth=0.3, near_spread=0.1):
  """Soft fraction of depth values past near_depth, ramping over near_spread.

  Each element contributes 0 below near_depth, 1 above
  near_depth + near_spread, and linearly in between.
  """
  ramp = (depth - near_depth) / near_spread
  return ramp.clamp(0.0, 1.0).mean().item()
def adjust_camera_vertically(camera, offset, tilt):
  """Return a new Camera raised by offset and tilted by tilt degrees."""
  return Camera(
      x=camera.x,
      y=camera.y + offset,
      z=camera.z,
      theta=camera.theta,
      psi=camera.psi + tilt,
  )
# layout model: adjust tilt and offset parameters based
# on near and land fraction
def update_tilt_and_offset(
    outputs,
    tilt,
    offset,
    horizon_target=horizon_target,
    near_target=near_target,
    tilt_velocity_scale=tilt_velocity_scale,
    offset_velocity_scale=offset_velocity_scale,
):  # pylint: disable=redefined-outer-name
  """Adjust tilt and offset based on rendered geometry.

  Args:
    outputs: dict with 'depth_up' (optional upsampled depth), 'depth_thumb'
      (fallback depth), and 'sky_mask' tensors.
    tilt: current tilt value to update.
    offset: current vertical offset to update.
    horizon_target: desired horizon position (see module constants).
    near_target: desired near-depth fraction.
    tilt_velocity_scale: gain for the tilt correction.
    offset_velocity_scale: gain for the offset correction.

  Returns:
    (tilt, offset) tuple nudged toward the targets.
  """
  # Prefer the upsampled depth map when the model produced one.
  if outputs['depth_up'] is not None:
    depth = outputs['depth_up'][0]
  else:
    depth = outputs['depth_thumb']
  sky_mask = outputs['sky_mask'][0]

  # Proportional control: steer the horizon via tilt and the near-depth
  # fraction via vertical offset.
  horizon = land_fraction(sky_mask)
  near = near_fraction(depth)
  new_tilt = tilt + tilt_velocity_scale * (horizon - horizon_target)
  new_offset = offset + offset_velocity_scale * (near - near_target)
  return new_tilt, new_offset
299# --------------------------------------------------------------------
300# camera interpolation utils
301
302
# Interpolate between random points
def interpolate_camera(start, end, l):
  """Blend two cameras; the heading takes the shorter angular path."""

  def blend(a, b):
    return b * l + a * (1 - l)

  # Wrap the end heading so interpolation goes the short way around.
  end_theta = end.theta
  if end.theta - start.theta > 180:
    end_theta -= 360
  if start.theta - end.theta > 180:
    end_theta += 360

  return Camera(
      blend(start.x, end.x),
      blend(start.y, end.y),
      blend(start.z, end.z),
      blend(start.theta, end_theta),
      blend(start.psi, end.psi),
  )
def ease(x):
  """Quadratic ease-in-out on [0, 1]: slow near 0 and 1, fastest at 0.5."""
  if x < 0.5:
    return 2 * x * x  # accelerating half
  inv = 1 - x
  return 1 - 2 * inv * inv  # decelerating half, mirrored
def lerp(a, b, l):
  """Linear interpolation: returns a at l == 0 and b at l == 1."""
  return b * l + a * (1 - l)
def random_camera(tlim=16, psi_multiplier=20):
  """Sample a random camera; higher cameras are tilted further down.

  Args:
    tlim: x and z are drawn uniformly from [-tlim, tlim].
    psi_multiplier: degrees of downward tilt per unit of height.

  Returns:
    A randomly placed Camera with height in [0, 2] and heading in [0, 360).
  """
  # Note: height must be drawn first to keep the RNG call order stable.
  height = random.uniform(0, 2)
  return Camera(
      x=random.uniform(-tlim, tlim),
      y=height,
      z=random.uniform(-tlim, tlim),
      theta=random.uniform(0, 360),
      psi=-psi_multiplier * height,
  )
def visualize_rays(G_terrain, Rt, xyz, layout, display_size, cam_grid=None):
  """Return an image showing the camera rays projected onto X-Z plane.

  Args:
    G_terrain: terrain generator; a layout model (has `layout_generator`)
      or a triplane model (has `backbone_xz`) — feature resolution and
      coordinate scale are read from the matching attributes.
    Rt: batched world2cam poses; Rt.inverse()[0] supplies the camera center.
    xyz: world-space ray sample points; last dim indexes (x, y, z).
    layout: feature grid; only its last-dim resolution is used.
    display_size: output image height/width in pixels.
    cam_grid: optional 3 x display_size x display_size canvas to draw over
      (it is cloned, not mutated); fresh zeros when None.

  Returns:
    3 x display_size x display_size tensor with ray points, a box marking
    the training region, and a block at the camera center.
  """
  # Rt = world2cam matrix

  if hasattr(G_terrain, 'layout_generator'):
    # layout model
    global_feat_res = G_terrain.layout_decoder.global_feat_res
    coordinate_scale = G_terrain.coordinate_scale
  else:
    # triplane model
    global_feat_res = G_terrain.backbone_xz.img_resolution
    coordinate_scale = G_terrain.rendering_kwargs['box_warp']

  inference_feat_res = layout.shape[-1]

  # compute pixel locations for camera points
  cam_frustum = xyz / (coordinate_scale / 2)  # normalize to [-1, 1]
  cam_frustum = (
      cam_frustum * global_feat_res / inference_feat_res
  )  # rescale for extended spatial grid
  cam_frustum = (cam_frustum + 1) / 2  # normalize to [0, 1]
  cam_frustum = (
      (cam_frustum * display_size).long().clamp(0, display_size - 1)
  )  # convert to [0, display size]

  # compute pixel locations for camera center
  tform_cam2world = Rt.inverse()
  cam_center = tform_cam2world[0, :3, -1]  # translation column = position
  cam_center = cam_center / (coordinate_scale / 2)
  cam_center = (
      cam_center * global_feat_res / inference_feat_res
  )  # rescale for extended spatial grid
  cam_center = (cam_center + 1) / 2
  cam_center = (cam_center * display_size).long().clamp(0, display_size - 1)

  # compute pixel locations for white box representing training region
  # (world coordinate -1 mapped through the same transform as above).
  orig_grid_offset = torch.Tensor([
      -1,
  ]) * (global_feat_res / inference_feat_res)
  orig_grid_offset = (orig_grid_offset + 1) / 2
  orig_grid_offset = (
      (orig_grid_offset * display_size).long().clamp(0, display_size - 1)
  )  # convert to [0, display size]

  if cam_grid is None:
    cam_grid = torch.zeros(3, display_size, display_size)
  else:
    cam_grid = cam_grid.clone().cpu()

  # plot everything on image: image row = world z, image column = world x.
  # Channel indices 1 and 0 presumably mean green/red in an RGB canvas —
  # confirm against the display code.
  cam_grid[
      1, cam_frustum[Ellipsis, 2].reshape(-1), cam_frustum[Ellipsis, 0].reshape(-1)
  ] = 1
  cam_grid[
      0, cam_frustum[Ellipsis, 2].reshape(-1), cam_frustum[Ellipsis, 0].reshape(-1)
  ] = 0.5
  # Four edges of the training-region box (offset from each border).
  cam_grid[:, orig_grid_offset, orig_grid_offset:-orig_grid_offset] = 1
  cam_grid[:, -orig_grid_offset, orig_grid_offset:-orig_grid_offset] = 1
  cam_grid[:, orig_grid_offset:-orig_grid_offset, orig_grid_offset] = 1
  cam_grid[:, orig_grid_offset:-orig_grid_offset, -orig_grid_offset] = 1
  # 4x4 block marking the camera center.
  cam_grid[
      0,
      cam_center[2] - 2 : cam_center[2] + 2,
      cam_center[0] - 2 : cam_center[0] + 2,
  ] = 1
  return cam_grid