# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions."""
import collections
import os
from os import path
from absl import flags
import flax
import flax.optim
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from PIL import Image
import yaml
from snerg.nerf import datasets

BASE_DIR = "snerg"
INTERNAL = False


@flax.struct.dataclass
class TrainState:
  optimizer: flax.optim.Optimizer


@flax.struct.dataclass
class Stats:
  loss: float
  psnr: float
  loss_c: float
  psnr_c: float
  weight_l2: float
  sparsity: float
  sparsity_c: float


Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))


def namedtuple_map(fn, tup):
  """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
  return type(tup)(*map(fn, tup))
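

# Illustrative usage (a sketch, not part of the original file): render_image()
# below uses namedtuple_map to flatten every field of a `Rays` tuple while
# keeping the namedtuple type intact. `_flatten_rays` is a hypothetical helper
# shown only for illustration.
def _flatten_rays(rays):
  """Flattens each (H, W, C) field of `rays` to (H * W, C)."""
  return namedtuple_map(lambda r: r.reshape((-1, r.shape[-1])), rays)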


def define_flags():
  """Define flags for both training and evaluation modes."""
  flags.DEFINE_string("train_dir", None, "where to store ckpts and logs.")
  flags.DEFINE_string("data_dir", None, "input data directory.")
  flags.DEFINE_string("config", None,
                      "config file used to set hyperparameters.")

  # Dataset Flags
  # TODO(pratuls): rename to dataset_loader and consider cleaning up
  flags.DEFINE_enum("dataset", "blender",
                    list(datasets.dataset_dict.keys()),
                    "The type of dataset to feed to NeRF.")
  flags.DEFINE_enum(
      "batching", "single_image", ["single_image", "all_images"],
      "source of ray sampling when collecting a training batch: "
      "single_image samples rays from only one image per batch, "
      "all_images samples rays from all the training images.")
  flags.DEFINE_bool(
      "white_bkgd", True, "use white as the default background color "
      "(used in the blender dataset only).")
  flags.DEFINE_integer("batch_size", 1024,
                       "the number of rays in a mini-batch (for training).")
  flags.DEFINE_integer("factor", 4,
                       "the downsample factor of images, 0 for no downsampling.")
  flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
  flags.DEFINE_bool(
      "render_path", False, "render the generated path if set to true "
      "(used in the llff dataset only).")
  flags.DEFINE_integer(
      "llffhold", 8, "will take every 1/N images as the LLFF test set "
      "(used in the llff dataset only).")
  flags.DEFINE_bool(
      "use_pixel_centers", False,
      "If True, generate rays through the center of each pixel. Note: While "
      "this is the correct way to handle rays, it is not the way rays are "
      "handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
      "compared to Vanilla NeRF.")

  # Model Flags
  flags.DEFINE_string("model", "nerf", "name of the model to use.")
  flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
  flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
  flags.DEFINE_integer("net_depth", 8, "depth of the first part of the MLP.")
  flags.DEFINE_integer("net_width", 256, "width of the first part of the MLP.")
  flags.DEFINE_integer("num_viewdir_channels", 4,
                       "number of extra channels used for view-dependence.")
  flags.DEFINE_integer("viewdir_net_depth", 2,
                       "depth of the view-dependence MLP.")
  flags.DEFINE_integer("viewdir_net_width", 16,
                       "width of the view-dependence MLP.")
  flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay.")
  flags.DEFINE_integer(
      "skip_layer", 4, "add a skip connection to the output vector every "
      "skip_layer layers.")
  flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
  flags.DEFINE_integer("num_sigma_channels", 1,
                       "the number of density channels.")
  flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
  flags.DEFINE_integer("min_deg_point", 0,
                       "Minimum degree of positional encoding for points.")
  flags.DEFINE_integer("max_deg_point", 10,
                       "Maximum degree of positional encoding for points.")
  flags.DEFINE_integer("deg_view", 4,
                       "Degree of positional encoding for viewdirs.")
  flags.DEFINE_integer(
      "num_coarse_samples", 64,
      "the number of samples on each ray for the coarse model.")
  flags.DEFINE_integer("num_fine_samples", 128,
                       "the number of samples on each ray for the fine model.")
  flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
  flags.DEFINE_float(
      "noise_std", None, "std dev of the noise added to regularize the sigma "
      "output (used in the llff dataset only).")
  flags.DEFINE_float(
      "sparsity_strength", 0.0, "weight for the sparsity loss "
      "(Cauchy on density).")
  flags.DEFINE_bool("lindisp", False,
                    "sample linearly in disparity rather than depth.")
  flags.DEFINE_bool(
      "legacy_posenc_order", False,
      "If True, revert the positional encoding feature order to an older "
      "version of this codebase.")

  # Train Flags
  flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
  flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
  flags.DEFINE_integer(
      "lr_delay_steps", 0, "The number of steps at the beginning of "
      "training over which the learning rate is reduced by lr_delay_mult.")
  flags.DEFINE_float(
      "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
      "is < lr_delay_steps.")
  flags.DEFINE_float("grad_max_norm", 0.,
                     "The gradient clipping magnitude (disabled if == 0).")
  flags.DEFINE_float("grad_max_val", 0.,
                     "The gradient clipping value (disabled if == 0).")

  flags.DEFINE_integer("max_steps", 1000000,
                       "the number of optimization steps.")
  flags.DEFINE_integer("save_every", 10000,
                       "the number of steps between checkpoint saves.")
  flags.DEFINE_integer("print_every", 100,
                       "the number of steps between reports to tensorboard.")
  flags.DEFINE_integer(
      "render_every", 5000, "the number of steps between test-image renders; "
      "ideally a multiple of 100 for accurate step-time records.")
  flags.DEFINE_integer("gc_every", 10000,
                       "the number of steps between runs of the python "
                       "garbage collector.")

  # Eval Flags
  flags.DEFINE_bool(
      "eval_once", True,
      "evaluate the model only once if true, otherwise keep evaluating new "
      "checkpoints as they appear.")
  flags.DEFINE_bool("save_output", True,
                    "save predicted images to disk if True.")
  flags.DEFINE_integer(
      "chunk", 8192,
      "the size of chunks for evaluation inference; set to a value that "
      "fits your GPU/TPU memory.")

  # Baking flags used for SNeRG
  flags.DEFINE_float(
      "voxel_filter_sigma", 1.0 / np.sqrt(12),
      "To prevent aliasing, we prefilter the NeRF volume with a 3D Gaussian "
      "on XYZ. This sigma controls how much to blur.")
  flags.DEFINE_integer("num_samples_per_voxel", 16,
                       "How many samples to use for the spatial prefilter.")
  flags.DEFINE_integer(
      "voxel_resolution", 1000,
      "The resolution of the voxel grid along the longest edge of the volume.")
  flags.DEFINE_integer(
      "snerg_chunk_size", 2**18,
      "The number of network queries to perform at a time. Lower this number "
      "if you run out of GPU or TPU memory.")
  flags.DEFINE_integer("culling_block_size", 16,
                       "The block size used for visibility and alpha culling.")
  flags.DEFINE_float(
      "alpha_threshold", 0.005,
      "We discard any atlas blocks where the max alpha is below this "
      "threshold.")
  flags.DEFINE_float(
      "visibility_threshold", 0.01,
      "We threshold on visibility = max(1 - alpha_along_ray) to create a "
      "visibility mask for all atlas blocks in the scene.")
  flags.DEFINE_float(
      "visibility_image_factor", 0.125,
      "Speedup: Scale the images by this factor when computing visibilities.")
  flags.DEFINE_integer(
      "visibility_subsample_factor", 2,
      "Speedup: Only process every Nth image when computing visibilities.")
  flags.DEFINE_integer(
      "visibility_grid_dilation", 2,
      "This makes the visibility grid conservative by dilating it slightly.")
  flags.DEFINE_integer(
      "atlas_block_size", 32,
      "The side length of a block stored in the volumetric texture atlas. Make "
      "sure to use a multiple of 16, so this fits nicely within image/video "
      "compression macroblocks.")
  flags.DEFINE_integer(
      "atlas_slice_size", 2048,
      "We store the atlas as a collection of atlas_slice_size x "
      "atlas_slice_size images.")
  flags.DEFINE_bool(
      "flip_scene_coordinates", True,
      "If the scenes have been defined in OpenCV coordinates (y-down, "
      "z-forward, e.g. blender or the default coordinate space for COLMAP), "
      "setting this to true will align the SNeRG grid with OpenGL coordinate "
      "conventions (y-up, z-backward).")
  flags.DEFINE_float(
      "snerg_box_scale", 1.7,
      "Scales the SNeRG voxel grid to fit the scene in [-scale, scale]^3.")
  flags.DEFINE_string(
      "snerg_dtype", "float32",
      "Data type used in the 3D texture atlas; float16 may conserve CPU RAM.")


def update_flags(args):
  """Update the flags in `args` with the contents of the config YAML file."""
  pth = path.join(BASE_DIR, args.config + ".yaml")
  with open_file(pth, "r") as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
  # Only allow args to be updated if they already exist.
  invalid_args = list(set(configs.keys()) - set(dir(args)))
  if invalid_args:
    raise ValueError(f"Invalid args {invalid_args} in {pth}.")
  args.__dict__.update(configs)
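

# A minimal example config (hypothetical values) illustrating what
# update_flags expects: a flat YAML mapping whose keys match flags defined in
# define_flags(). Passing --config=my_config loads <BASE_DIR>/my_config.yaml;
# any key that is not an existing flag raises a ValueError.
#
#   # snerg/my_config.yaml
#   batch_size: 2048
#   lr_init: 1.0e-3
#   max_steps: 250000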


def open_file(pth, mode="r"):
  if not INTERNAL:
    return open(pth, mode=mode)


def file_exists(pth):
  if not INTERNAL:
    return path.exists(pth)


def listdir(pth):
  if not INTERNAL:
    return os.listdir(pth)


def isdir(pth):
  if not INTERNAL:
    return path.isdir(pth)


def makedirs(pth):
  if not INTERNAL:
    os.makedirs(pth)


def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
    features: jnp.ndarray, rendered feature image.
    specular: jnp.ndarray, rendered specular residual.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  results = []
  for i in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
    chunk_size = chunk_rays[0].shape[0]
    rays_remaining = chunk_size % jax.device_count()
    if rays_remaining != 0:
      padding = jax.device_count() - rays_remaining
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    else:
      padding = 0
    # After padding, the number of chunk_rays is always divisible by
    # host_count.
    rays_per_host = chunk_rays[0].shape[0] // jax.host_count()
    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    results.append([unshard(x[0], padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop
  rgb, disp, acc, _, features, specular = [
      jnp.concatenate(r, axis=0) for r in zip(*results)
  ]
  # Normalize disp for visualization, for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)), disp.reshape(
      (height, width, -1)), acc.reshape(
          (height, width, -1)), features.reshape(
              (height, width, -1)), specular.reshape((height, width, -1)))
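

# Worked example of the padding arithmetic in render_image (illustrative):
# with jax.device_count() == 8 and a final chunk of 8190 rays, rays_remaining
# is 8190 % 8 == 6, so padding == 2 edge-replicated rays are appended before
# sharding, and unshard(..., padding) strips them again after rendering.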


def compute_psnr(mse):
  """Compute psnr value given mse (we assume the maximum pixel value is 1).

  Args:
    mse: float, mean squared error of pixels.

  Returns:
    psnr: float, the psnr value.
  """
  return -10. * jnp.log(mse) / jnp.log(10.)
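

# Sanity check (illustrative): since PSNR = -10 * log10(MSE) with a peak pixel
# value of 1, an MSE of 1e-2 corresponds to 20 dB and an MSE of 1e-4 to 40 dB:
#
#   compute_psnr(jnp.array(1e-2))  # ~20.0
#   compute_psnr(jnp.array(1e-4))  # ~40.0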


def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, the per-pixel SSIM "map" is returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than a 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs over the tensor's leading and channel axes, then compose.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0**2) - mu00
  sigma11 = filt_fn(img1**2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  # The covariance magnitude is bounded by the geometric mean of the variances:
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim
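

# Illustrative sketch (not part of the original file): comparing an image with
# itself gives SSIM == 1 by construction, since then mu0 == mu1 and
# sigma00 == sigma11 == sigma01, so numer == denom everywhere.
def _example_ssim_self(img):
  """Returns ~1.0 for any [..., width, height, channels] image."""
  return compute_ssim(img, img, max_val=1.0)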


def save_img(img, pth):
  """Save an image to disk.

  Args:
    img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
      before being saved to pth.
    pth: string, path to save the image to.
  """
  with open_file(pth, "wb") as imgout:
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(np.uint8))).save(imgout, "PNG")


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
  is log-linearly interpolated elsewhere (equivalent to exponential decay).
  If lr_delay_steps>0 then the learning rate will be scaled by some smooth
  function of lr_delay_mult, such that the initial learning rate is
  lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  to the normal learning rate when steps>lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for the current step.
  """
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp
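

# Worked example (illustrative): with lr_delay_steps == 0 the schedule is a
# pure log-linear interpolation, so halfway through training the rate is the
# geometric mean of the endpoints:
#
#   learning_rate_decay(step=500000, lr_init=5e-4, lr_final=5e-6,
#                       max_steps=1000000)  # == sqrt(5e-4 * 5e-6) == 5e-5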


def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)


def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.device_put(xs)


def unshard(x, padding=0):
  """Collect the sharded tensor to the shape before sharding."""
  y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
  if padding > 0:
    y = y[:-padding]
  return y
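

# Illustrative round trip (a sketch, not part of the original file): shard()
# splits the leading axis across local devices and unshard() undoes it,
# assuming that axis is divisible by jax.local_device_count() (render_image
# pads each chunk to guarantee this).
def _example_shard_roundtrip(x):
  """Returns an array equal to `x` when x.shape[0] divides across devices."""
  return unshard(shard(x))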