google-research

Форк
0
409 строк · 10.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# pylint: disable=invalid-name,g-importing-member,g-multiple-import
17
"""Utilities for camera manipulation."""
18
from collections import namedtuple
19
from math import cos, sin, pi
20
import random
21

22
import numpy as np
23
import torch
24

25

26
####### camera utils
27

28
# Tuple to represent user camera position
29
# Fields: x, y, z give the camera position; theta is the horizontal viewing
# direction in degrees (0 = positive x); psi is the up/down angle in degrees
# (0 = level).
Camera = namedtuple('Camera', 'x y z theta psi')


def initial_camera():
  """Return a level camera at the origin with theta = 0 (facing +x)."""
  return Camera(x=0.0, y=0.0, z=0.0, theta=0.0, psi=0.0)
43

44

45
# Camera movement constants
46
# Degrees turned per 'a'/'d' (heading) or 'r'/'f' (tilt) key press.
ROTATION_HORIZONTAL_DEGREES = ROTATION_UPDOWN_DEGREES = 5
# Range that psi (the up/down angle, in degrees) is clamped to.
UPDOWN_MIN, UPDOWN_MAX = -90, 90
# Distance moved per forward/backward key press; sideways and vertical moves
# use half of it.
FORWARD_SPEED = 1 / 2
SIDEWAYS_SPEED = VERTICAL_SPEED = FORWARD_SPEED / 2
# Camera returned by the reset key; None means "use initial_camera()".
INITIAL_CAMERA = None
54

55

56
def pose_from_camera(camera):
  """A 4x4 pose matrix mapping world to camera space.

  The pose is Rx(psi) @ Ry(theta + 90) @ T(-x, -y, -z): the camera position
  is translated to the origin, then rotated about the y axis by
  (theta + 90) degrees and about the x axis by psi degrees.

  Args:
    camera: camera object; reads x, y, z (position) and theta, psi (degrees).

  Returns:
    world2cam matrix, a 4x4 torch.Tensor of the default floating dtype.
  """
  # Build every matrix with one floating dtype. Otherwise integer coordinates
  # (e.g. Camera(1, 2, 3, 0, 0)) make T an int64 tensor and torch.mm fails on
  # the float/int64 mismatch.
  dtype = torch.get_default_dtype()
  theta_rad = (camera.theta + 90) * pi / 180
  psi_rad = camera.psi * pi / 180
  cos_theta, sin_theta = cos(theta_rad), sin(theta_rad)
  cos_psi, sin_psi = cos(psi_rad), sin(psi_rad)
  Ry = torch.tensor(
      [
          [cos_theta, 0, sin_theta, 0],
          [0, 1, 0, 0],
          [-sin_theta, 0, cos_theta, 0],
          [0, 0, 0, 1],
      ],
      dtype=dtype,
  )
  Rx = torch.tensor(
      [
          [1, 0, 0, 0],
          [0, cos_psi, sin_psi, 0],
          [0, -sin_psi, cos_psi, 0],
          [0, 0, 0, 1],
      ],
      dtype=dtype,
  )
  T = torch.tensor(
      [
          [1, 0, 0, -camera.x],
          [0, 1, 0, -camera.y],
          [0, 0, 1, -camera.z],
          [0, 0, 0, 1],
      ],
      dtype=dtype,
  )
  return torch.mm(torch.mm(Rx, Ry), T)
88

89

90
def camera_from_pose(Rt):
  """Solve for camera variables from a world2cam pose.

  Inverse of `pose_from_camera`, which builds
  Rt = Rx(psi) @ Ry(theta + 90) @ T(-x, -y, -z).

  Args:
    Rt: 4x4 torch.Tensor, world2cam pose. It is copied to the CPU as float64
      before solving, so any device and floating dtype is accepted.

  Returns:
    camera object; theta (degrees) is wrapped with `% 360`, psi is in
    degrees.
  """
  assert list(Rt.shape) == [4, 4]
  # A float64 CPU copy keeps the torch.mm below free of dtype/device
  # mismatches (e.g. a float64 Rt against the default-dtype rotation).
  pose = Rt.detach().cpu().to(torch.float64)

  # theta: row 0 of Rx @ Ry is [cos(theta + 90), 0, sin(theta + 90)].
  theta = torch.atan2(pose[0, 2], pose[0, 0]).item() * 180 / pi
  theta = (theta - 90) % 360  # undo the 90 degree offset

  # psi: column 1 of Rx @ Ry is [0, cos(psi), -sin(psi)]. atan2 keeps the
  # correct sign when cos(psi) is ~0 or slightly negative from round-off
  # near psi = +/-90, where atan(sin / cos) would jump to the other pole.
  psi = torch.atan2(-pose[2, 1], pose[1, 1]).item() * 180 / pi

  # Strip the rotation (Rx @ Ry) to recover T(-x, -y, -z).
  rotation = pose_from_camera(Camera(0.0, 0.0, 0.0, theta, psi))
  translation = torch.mm(rotation.to(torch.float64).inverse(), pose)
  return Camera(
      -translation[0, 3].item(),
      -translation[1, 3].item(),
      -translation[2, 3].item(),
      theta,
      psi,
  )
125

126

127
def get_full_image_parameters(
    layout_model,
    nerf_render_size,
    batch_size,
    device='cuda',
    Rt=None,
    sample_fov=False,
):
  """Construct intrinsics for a square image of side `nerf_render_size`.

  Args:
    layout_model: object exposing `fov_mean` and `fov_std`; the FOV is fed
      through np.deg2rad, so it is in degrees.
    nerf_render_size: image side length in pixels.
    batch_size: number of cameras B.
    device: device the intrinsics tensor 'K' is moved to.
    Rt: optional world2cam poses, Bx4x4 or Bx1x4x4 (the singleton axis is
      dropped).
    sample_fov: if True, each camera gets fov_mean + fov_std * N(0, 1);
      otherwise every camera uses fov_mean.

  Returns:
    dict with 'K' (Bx3x3 float on `device`), 'global_size' (B float),
    'fov' (B float, degrees) and, when Rt is given, 'Rt' (Bx4x4).
  """
  fov_mean = layout_model.fov_mean
  fov_scale = layout_model.fov_std if sample_fov else 0.0
  # One batch of normal draws is consumed in both modes, so the global NumPy
  # RNG stream advances identically whether or not the FOV is sampled.
  fov = fov_mean + fov_scale * np.random.randn(batch_size)

  sizes = np.array([nerf_render_size] * batch_size)
  half_fov = np.deg2rad(fov) / 2
  focal = (sizes / 2) / np.tan(half_fov)

  intrinsics = np.zeros((batch_size, 3, 3))
  intrinsics[:, 0, 0] = focal
  intrinsics[:, 1, 1] = -focal
  intrinsics[:, 2, 2] = -1

  camera_params = {
      'K': torch.from_numpy(intrinsics).float().to(device),
      'global_size': torch.from_numpy(sizes).float(),
      'fov': torch.from_numpy(fov).float(),
  }
  if Rt is None:
    return camera_params

  if Rt.ndim == 4:
    assert Rt.shape[1] == 1
    Rt = Rt[:, 0]
  camera_params['Rt'] = Rt  # Bx4x4
  return camera_params
163

164

165
# --------------------------------------------------------------------
166
# camera motion utils
167

168

169
def update_camera(camera, key, auto_adjust_height_and_tilt=True):
  """Move the camera according to a pressed key.

  Keys: 'x' reset; 'a'/'d' turn theta by -/+ROTATION_HORIZONTAL_DEGREES;
  'r'/'f' tilt psi by +/-ROTATION_UPDOWN_DEGREES; 'w'/'s' forward/backward by
  FORWARD_SPEED along the heading (cos theta, sin theta) in the x-z plane;
  'q'/'e' left/right by SIDEWAYS_SPEED; 't'/'g' up/down by VERTICAL_SPEED.

  Args:
    camera: current camera.
    key: the pressed key.
    auto_adjust_height_and_tilt: when True, 'r', 'f', 't' and 'g' are ignored
      and `camera` itself is returned.

  Returns:
    For 'x', INITIAL_CAMERA if it is set, else initial_camera(). Otherwise a
    new Camera with theta wrapped by `% 360` and psi clamped to
    [UPDOWN_MIN, UPDOWN_MAX].
  """
  if key == 'x':
    return initial_camera() if INITIAL_CAMERA is None else INITIAL_CAMERA

  if auto_adjust_height_and_tilt and key in ('r', 'f', 't', 'g'):
    # Height and tilt are handled automatically; drop manual controls.
    return camera

  x, y, z = camera.x, camera.y, camera.z
  theta, psi = camera.theta, camera.psi
  # Heading in the x-z plane, taken before any rotation from this key.
  heading_x = cos(theta * pi / 180)
  heading_z = sin(theta * pi / 180)

  if key == 'a':
    theta -= ROTATION_HORIZONTAL_DEGREES
  elif key == 'd':
    theta += ROTATION_HORIZONTAL_DEGREES
  elif key == 'r':
    psi += ROTATION_UPDOWN_DEGREES
  elif key == 'f':
    psi -= ROTATION_UPDOWN_DEGREES
  elif key == 'w':  # forward
    x += heading_x * FORWARD_SPEED
    z += heading_z * FORWARD_SPEED
  elif key == 's':  # backward
    x -= heading_x * FORWARD_SPEED
    z -= heading_z * FORWARD_SPEED
  elif key == 'q':  # left
    x -= -heading_z * SIDEWAYS_SPEED
    z -= heading_x * SIDEWAYS_SPEED
  elif key == 'e':  # right
    x += -heading_z * SIDEWAYS_SPEED
    z += heading_x * SIDEWAYS_SPEED
  elif key == 't':  # up
    y += VERTICAL_SPEED
  elif key == 'g':  # down
    y -= VERTICAL_SPEED

  theta %= 360
  psi = max(UPDOWN_MIN, min(UPDOWN_MAX, psi))
  return Camera(x, y, z, theta, psi)
228

229

230
def move_camera(camera, forward_speed, rotation_speed):
  """Turn the camera, then move it along its new heading in the x-z plane.

  Args:
    camera: current camera.
    forward_speed: distance to travel along the heading
      (cos theta, sin theta); negative values move backward.
    rotation_speed: degrees added to theta before moving.

  Returns:
    A new Camera with y and psi unchanged and theta wrapped into [0, 360).
  """
  theta = camera.theta + rotation_speed
  heading = theta * pi / 180
  x = camera.x + cos(heading) * forward_speed
  z = camera.z + sin(heading) * forward_speed
  # Wrap theta the same way update_camera does, so repeated calls keep it
  # bounded and interpolate_camera's single +/-360 unwrap stays sufficient.
  return Camera(x, camera.y, z, theta % 360, camera.psi)
241

242

243
# --------------------------------------------------------------------
244
# camera balancing utils
245

246
# Ideal position of the horizon, as a fraction of the way up the image
# (suggested range: 0.5 to 0.7).
horizon_target = 0.65

# Ideal proportion of the depth map that is "near" the camera. Smaller values
# make the camera fly higher (suggested range: 0.05 to 0.2).
near_target = 0.2

# Gains applied in update_tilt_and_offset to the horizon and near-fraction
# errors when updating the tilt and the height offset.
tilt_velocity_scale = 0.3
offset_velocity_scale = 0.5
257

258

259
def land_fraction(sky_mask):
  """Return the mean value of `sky_mask` as a Python float.

  Despite the name, this is simply the mean of the mask passed as `sky_mask`;
  update_tilt_and_offset uses it as its horizon estimate. Whether 1 marks sky
  or land depends on whoever produces the mask. NOTE(review): confirm.

  Boolean and integer masks are converted to float first, because torch.mean
  only accepts floating-point (or complex) inputs.
  """
  if not (sky_mask.is_floating_point() or sky_mask.is_complex()):
    sky_mask = sky_mask.float()
  return torch.mean(sky_mask).item()
261

262

263
def near_fraction(depth, near_depth=0.3, near_spread=0.1):
  """Return the mean of a soft threshold on `depth` as a Python float.

  Each value is mapped linearly from 0 at `near_depth` to 1 at
  `near_depth + near_spread`, clamped to [0, 1], then averaged. The ramp
  rises with the value, so it is 1 for *larger* inputs. Calling that "near"
  presumes `depth` is something like disparity, where larger means closer.
  NOTE(review): confirm against the model's depth output.
  """
  ramp = (depth - near_depth) / near_spread
  return ramp.clamp(0.0, 1.0).mean().item()
266

267

268
def adjust_camera_vertically(camera, offset, tilt):
  """Return a new Camera with `offset` added to y and `tilt` added to psi.

  x, z and theta are copied unchanged; `tilt` is in degrees like psi.
  """
  raised_y = camera.y + offset
  tilted_psi = camera.psi + tilt
  return Camera(camera.x, raised_y, camera.z, camera.theta, tilted_psi)
272

273

274
# layout model: adjust tilt and offset parameters based
275
# on near and land fraction
276
def update_tilt_and_offset(
    outputs,
    tilt,
    offset,
    horizon_target=horizon_target,
    near_target=near_target,
    tilt_velocity_scale=tilt_velocity_scale,
    offset_velocity_scale=offset_velocity_scale,
):  # pylint: disable=redefined-outer-name
  """Adjust tilt and offset based on the rendered geometry.

  The horizon estimate is land_fraction(outputs['sky_mask'][0]) and the near
  estimate is near_fraction(depth). Each one's error against its target is
  scaled and added to tilt / offset respectively.

  Args:
    outputs: dict of render outputs; reads 'depth_up' (may be None),
      'depth_thumb' (used when 'depth_up' is None) and 'sky_mask'.
    tilt: current tilt value.
    offset: current height offset.
    horizon_target, near_target, tilt_velocity_scale, offset_velocity_scale:
      default to the module-level values of the same names, captured when
      this function is defined.

  Returns:
    (tilt, offset) after the update. `+=` is used, so tensor or array
    arguments are modified in place.
  """
  depth_up = outputs['depth_up']
  # NOTE(review): depth_up is indexed with [0] but depth_thumb is used as-is;
  # confirm depth_thumb is already unbatched.
  depth = outputs['depth_thumb'] if depth_up is None else depth_up[0]
  horizon = land_fraction(outputs['sky_mask'][0])
  near = near_fraction(depth)

  tilt += tilt_velocity_scale * (horizon - horizon_target)
  offset += offset_velocity_scale * (near - near_target)
  return tilt, offset
297

298

299
# --------------------------------------------------------------------
300
# camera interpolation utils
301

302

303
# Interpolate between random points
304
def interpolate_camera(start, end, l):
  """Linearly interpolate every field from `start` (l = 0) to `end` (l = 1).

  theta takes the short way around the circle: if the headings differ by
  more than 180 degrees, `end`'s heading is shifted by 360 first. The
  result's theta is not re-wrapped, so it can fall outside [0, 360).
  """
  def mix(a, b):
    return b * l + a * (1 - l)

  if end.theta - start.theta > 180:
    target_theta = end.theta - 360
  elif start.theta - end.theta > 180:
    target_theta = end.theta + 360
  else:
    target_theta = end.theta

  return Camera(
      mix(start.x, end.x),
      mix(start.y, end.y),
      mix(start.z, end.z),
      mix(start.theta, target_theta),
      mix(start.psi, end.psi),
  )
320

321

322
def ease(x):
  """Quadratic ease-in/ease-out curve on [0, 1].

  Follows 2x^2 below x = 0.5 and its mirror image 1 - 2(1 - x)^2 from 0.5
  on, so ease(0) = 0, ease(0.5) = 0.5 and ease(1) = 1.
  """
  remaining = 1 - x
  return 2 * x * x if x < 0.5 else 1 - 2 * remaining * remaining
326

327

328
def lerp(a, b, l):
  """Linear interpolation: returns `a` at l = 0 and `b` at l = 1."""
  weight_a = 1 - l
  return a * weight_a + b * l
330

331

332
def random_camera(tlim=16, psi_multiplier=20):
  """Sample a random camera using the `random` module.

  The height y is drawn from U(0, 2) and psi is set to
  -psi_multiplier * height degrees; x and z come from U(-tlim, tlim) and
  theta from U(0, 360). The draws happen in the order height, x, z, theta;
  keep that order so seeded runs stay reproducible.
  """
  height = random.uniform(0, 2)
  x = random.uniform(-tlim, tlim)
  z = random.uniform(-tlim, tlim)
  theta = random.uniform(0, 360)
  return Camera(x, height, z, theta, -psi_multiplier * height)
342

343

344
def _normalized_to_display_pixels(normalized, display_size):
  """Map coordinates in [-1, 1] to clamped integer pixel indices on the CPU."""
  unit = (normalized + 1) / 2  # normalize to [0, 1]
  return (unit * display_size).long().clamp(0, display_size - 1).cpu()


def _world_to_display_pixels(
    points, coordinate_scale, global_feat_res, inference_feat_res, display_size
):
  """Map world-space coordinates to integer pixel indices of the display."""
  normalized = points / (coordinate_scale / 2)  # normalize to [-1, 1]
  # rescale for the extended spatial grid used at inference
  normalized = normalized * global_feat_res / inference_feat_res
  return _normalized_to_display_pixels(normalized, display_size)


def visualize_rays(G_terrain, Rt, xyz, layout, display_size, cam_grid=None):
  """Return an image showing the camera rays projected onto the X-Z plane.

  Args:
    G_terrain: generator. If it has a `layout_generator` attribute it is
      treated as a layout model (resolution from
      `layout_decoder.global_feat_res`, scale from `coordinate_scale`);
      otherwise as a triplane model (`backbone_xz.img_resolution`,
      `rendering_kwargs['box_warp']`).
    Rt: world2cam matrix. Only `Rt.inverse()[0, :3, -1]` is read, so a
      batched input is expected (presumably (B, 4, 4); TODO confirm).
    xyz: world-space ray points; along the last axis, index 0 (x) becomes the
      column and index 2 (z) the row.
    layout: inference-time layout; only `layout.shape[-1]` is read.
    display_size: side length in pixels of the square output.
    cam_grid: optional (3, display_size, display_size) image to draw on. It
      is cloned and moved to the CPU, so the caller's tensor is untouched.

  Returns:
    A (3, display_size, display_size) CPU tensor. Ray points set channel 1
    to 1 and channel 0 to 0.5, the training-region box sets all channels to
    1, and the camera center is a square of up to 4x4 pixels in channel 0.
  """
  if hasattr(G_terrain, 'layout_generator'):
    # layout model
    global_feat_res = G_terrain.layout_decoder.global_feat_res
    coordinate_scale = G_terrain.coordinate_scale
  else:
    # triplane model
    global_feat_res = G_terrain.backbone_xz.img_resolution
    coordinate_scale = G_terrain.rendering_kwargs['box_warp']
  inference_feat_res = layout.shape[-1]

  # Pixel locations of the ray points and of the camera center (the
  # translation of cam2world). Indices live on the CPU to match cam_grid.
  cam_frustum = _world_to_display_pixels(
      xyz, coordinate_scale, global_feat_res, inference_feat_res, display_size
  )
  tform_cam2world = Rt.inverse()
  cam_center = _world_to_display_pixels(
      tform_cam2world[0, :3, -1],
      coordinate_scale,
      global_feat_res,
      inference_feat_res,
      display_size,
  )

  # Offset of the box marking the training region (normalized coord -1).
  box_lo = int(
      _normalized_to_display_pixels(
          torch.Tensor([-1]) * (global_feat_res / inference_feat_res),
          display_size,
      )
  )
  box_hi = display_size - box_lo  # exclusive end of the box edges
  # Index of the bottom/right edge. Writing it as -box_lo would alias row 0
  # (and `0:-0` is an empty slice) when box_lo clamps to 0, so the box would
  # vanish; clamp to the last pixel instead.
  box_last = min(box_hi, display_size - 1)

  if cam_grid is None:
    cam_grid = torch.zeros(3, display_size, display_size)
  else:
    cam_grid = cam_grid.clone().cpu()

  # Ray points: z -> row, x -> column.
  rows = cam_frustum[Ellipsis, 2].reshape(-1)
  cols = cam_frustum[Ellipsis, 0].reshape(-1)
  cam_grid[1, rows, cols] = 1
  cam_grid[0, rows, cols] = 0.5

  # Training-region box.
  cam_grid[:, box_lo, box_lo:box_hi] = 1
  cam_grid[:, box_last, box_lo:box_hi] = 1
  cam_grid[:, box_lo:box_hi, box_lo] = 1
  cam_grid[:, box_lo:box_hi, box_last] = 1

  # Camera center. Clamp the slice start at 0: a negative start would wrap
  # around and yield an empty slice for centers near the top/left edge.
  center_row = int(cam_center[2])
  center_col = int(cam_center[0])
  cam_grid[
      0,
      max(center_row - 2, 0) : center_row + 2,
      max(center_col - 2, 0) : center_col + 2,
  ] = 1
  return cam_grid
410

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.