# Source: google-research
# 478 lines · 17.8 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utility functions."""
17import collections
18import os
19from os import path
20from absl import flags
21import flax
22import flax.optim
23import jax
24import jax.numpy as jnp
25import jax.scipy as jsp
26import numpy as np
27from PIL import Image
28import yaml
29from snerg.nerf import datasets
30
# Directory against which config names are resolved (`<BASE_DIR>/<config>.yaml`).
BASE_DIR = "snerg"

# Selects the file-I/O backend used by the file helpers in this module; only
# the `INTERNAL = False` (local filesystem via `os` / builtin `open`) path is
# implemented in this file.
INTERNAL = False
33
34
@flax.struct.dataclass
class TrainState:
  """Training state, stored as a flax struct dataclass (frozen and a pytree)."""
  # The flax optimizer object; it owns both the parameters being trained and
  # the optimizer's internal state.
  optimizer: flax.optim.Optimizer
38
39
@flax.struct.dataclass
class Stats:
  """Per-step training statistics, stored as a flax struct dataclass.

  Fields are annotated as `float`; the `_c` suffix presumably denotes the same
  quantity computed for the coarse model (TODO confirm against the trainer).
  """
  loss: float  # Training loss.
  psnr: float  # PSNR of the rendered colors.
  loss_c: float  # Loss, presumably for the coarse model.
  psnr_c: float  # PSNR, presumably for the coarse model.
  weight_l2: float  # L2 magnitude of the weights (weight-decay term).
  sparsity: float  # Sparsity loss value.
  sparsity_c: float  # Sparsity loss, presumably for the coarse model.
49
50
# Named bundle of per-ray data. Each field is expected to be an array with the
# same leading dimensions (render_image reads the image size from field 0 and
# reshapes every field the same way).
Rays = collections.namedtuple("Rays", "origins directions viewdirs")
52
53
def namedtuple_map(fn, tup):
  """Return a new namedtuple of `tup`'s type whose fields are `fn(field)`.

  Args:
    fn: callable applied to each field of `tup`, in field order.
    tup: a namedtuple instance.

  Returns:
    An instance of `type(tup)` built from the mapped fields.
  """
  tuple_type = type(tup)
  mapped_fields = [fn(field) for field in tup]
  return tuple_type(*mapped_fields)
57
58
def define_flags():
  """Define flags for both training and evaluation modes.

  Registers, with absl `flags`, every dataset, model, training, evaluation and
  SNeRG-baking flag used by this project.
  """
  flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
  flags.DEFINE_string("data_dir", None, "input data directory.")
  flags.DEFINE_string("config", None,
                      "using config files to set hyperparameters.")

  # Dataset Flags
  # TODO(pratuls): rename to dataset_loader and consider cleaning up
  flags.DEFINE_enum("dataset", "blender",
                    list(datasets.dataset_dict.keys()),
                    "The type of dataset fed to nerf.")
  flags.DEFINE_enum(
      "batching", "single_image", ["single_image", "all_images"],
      "source of ray sampling when collecting training batch, "
      "single_image for sampling from only one image in a batch, "
      "all_images for sampling from all the training images.")
  flags.DEFINE_bool(
      "white_bkgd", True, "using white color as default background. "
      "(used in the blender dataset only)")
  flags.DEFINE_integer("batch_size", 1024,
                       "the number of rays in a mini-batch (for training).")
  flags.DEFINE_integer("factor", 4,
                       "the downsample factor of images, 0 for no downsample.")
  flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
  flags.DEFINE_bool(
      "render_path", False, "render generated path if set true. "
      "(used in the llff dataset only)")
  flags.DEFINE_integer(
      "llffhold", 8, "will take every 1/N images as LLFF test set. "
      "(used in the llff dataset only)")
  flags.DEFINE_bool(
      "use_pixel_centers", False,
      "If True, generate rays through the center of each pixel. Note: While "
      "this is the correct way to handle rays, it is not the way rays are "
      "handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
      "compared to Vanilla NeRF.")

  # Model Flags
  flags.DEFINE_string("model", "nerf", "name of model to use.")
  flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
  flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
  flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
  flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
  flags.DEFINE_integer("num_viewdir_channels", 4,
                       "number of extra channels used for view-dependence.")
  flags.DEFINE_integer("viewdir_net_depth", 2,
                       "depth of the view-dependence MLP.")
  flags.DEFINE_integer("viewdir_net_width", 16,
                       "width of the view-dependence MLP.")
  flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
  flags.DEFINE_integer(
      "skip_layer", 4, "add a skip connection to the output vector of every "
      "skip_layer layers.")
  flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
  flags.DEFINE_integer("num_sigma_channels", 1,
                       "the number of density channels.")
  flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
  flags.DEFINE_integer("min_deg_point", 0,
                       "Minimum degree of positional encoding for points.")
  flags.DEFINE_integer("max_deg_point", 10,
                       "Maximum degree of positional encoding for points.")
  flags.DEFINE_integer("deg_view", 4,
                       "Degree of positional encoding for viewdirs.")
  flags.DEFINE_integer(
      "num_coarse_samples", 64,
      "the number of samples on each ray for the coarse model.")
  flags.DEFINE_integer("num_fine_samples", 128,
                       "the number of samples on each ray for the fine model.")
  flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
  flags.DEFINE_float(
      "noise_std", None, "std dev of noise added to regularize sigma output. "
      "(used in the llff dataset only)")
  flags.DEFINE_float(
      "sparsity_strength", 0.0, "weight for the sparsity loss "
      "(Cauchy on density).")
  flags.DEFINE_bool("lindisp", False,
                    "sampling linearly in disparity rather than depth.")
  flags.DEFINE_bool(
      "legacy_posenc_order", False,
      "If True, revert the positional encoding feature order to an older "
      "version of this codebase.")

  # Train Flags
  flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
  flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
  flags.DEFINE_integer(
      "lr_delay_steps", 0, "The number of steps at the beginning of "
      "training to reduce the learning rate by lr_delay_mult")
  flags.DEFINE_float(
      "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
      "is < lr_delay_steps")
  flags.DEFINE_float("grad_max_norm", 0.,
                     "The gradient clipping magnitude (disabled if == 0).")
  flags.DEFINE_float("grad_max_val", 0.,
                     "The gradient clipping value (disabled if == 0).")

  flags.DEFINE_integer("max_steps", 1000000,
                       "the number of optimization steps.")
  flags.DEFINE_integer("save_every", 10000,
                       "the number of steps to save a checkpoint.")
  flags.DEFINE_integer("print_every", 100,
                       "the number of steps between reports to tensorboard.")
  flags.DEFINE_integer(
      "render_every", 5000, "the number of steps to render a test image, "
      "better to be x00 for accurate step time record.")
  flags.DEFINE_integer("gc_every", 10000,
                       "the number of steps to run python garbage collection.")

  # Eval Flags
  flags.DEFINE_bool(
      "eval_once", True,
      "evaluate the model only once if true, otherwise keep evaluating new "
      "checkpoints if there are any.")
  flags.DEFINE_bool("save_output", True,
                    "save predicted images to disk if True.")
  flags.DEFINE_integer(
      "chunk", 8192,
      "the size of chunks for evaluation inferences, set to the value that "
      "fits your GPU/TPU memory.")

  # Baking flags used for SNeRG
  # NOTE(review): 1/sqrt(12) is the standard deviation of a uniform
  # distribution over a unit interval, presumably chosen to match a one-voxel
  # box filter.
  flags.DEFINE_float(
      "voxel_filter_sigma", 1.0 / np.sqrt(12),
      "To prevent aliasing, we prefilter the NeRF volume with a 3D Gaussian "
      "on XYZ. This sigma controls how much to blur.")
  flags.DEFINE_integer("num_samples_per_voxel", 16,
                       "How many samples to use for the spatial prefilter.")
  flags.DEFINE_integer(
      "voxel_resolution", 1000,
      "The resolution of the voxel grid along the longest edge of the volume.")
  flags.DEFINE_integer(
      "snerg_chunk_size", 2**18,
      "The number of network queries to perform at a time. Lower this number "
      "if you run out of GPU or TPU memory.")
  flags.DEFINE_integer("culling_block_size", 16,
                       "The block size used for visibility and alpha culling.")
  flags.DEFINE_float(
      "alpha_threshold", 0.005,
      "We discard any atlas blocks where the max alpha is below this "
      "threshold.")
  flags.DEFINE_float(
      "visibility_threshold", 0.01,
      "We threshold on visibility = max(1 - alpha_along_ray), to create a "
      "visibility mask for all atlas blocks in the scene.")
  flags.DEFINE_float(
      "visibility_image_factor", 0.125,
      "Speedup: Scale the images by this factor when computing visibilities.")
  flags.DEFINE_integer(
      "visibility_subsample_factor", 2,
      "Speedup: Only process every Nth image when computing visibilities.")
  flags.DEFINE_integer(
      "visibility_grid_dilation", 2,
      "This makes the visibility grid conservative by dilating it slightly.")
  flags.DEFINE_integer(
      "atlas_block_size", 32,
      "The side length of a block stored in the volumetric texture atlas. "
      "Make sure to use a multiple of 16, so this fits nicely within "
      "image/video compression macroblocks.")
  flags.DEFINE_integer(
      "atlas_slice_size", 2048,
      "We store the atlas as a collection of atlas_slice_size * "
      "atlas_slice_size images.")
  flags.DEFINE_bool(
      "flip_scene_coordinates", True,
      "If the scenes have been defined in OpenCV coordinates (y-down, "
      "z-forward, e.g. blender or the default coordinate space for COLMAP), "
      "setting this to true will align the SNeRG grid with OpenGL coordinate "
      "conventions (y-up, z-backward).")
  flags.DEFINE_float(
      "snerg_box_scale", 1.7,
      "Scales the SNeRG voxel grid to fit the scene in [-scale, scale]^3.")
  flags.DEFINE_string(
      "snerg_dtype", "float32",
      "Data-type used in the 3D texture atlas, float16 may conserve CPU RAM.")
236
237
def update_flags(args):
  """Update the flags in `args` with the contents of the config YAML file.

  The file read is `BASE_DIR/<args.config>.yaml`. Only keys that already exist
  as attributes of `args` may be set, so a typo in a config file is reported
  instead of being silently ignored. An empty config file means "no
  overrides".

  Args:
    args: the flag-values object to update (e.g. absl `FLAGS`); its `config`
      attribute names the YAML file, without extension, relative to
      `BASE_DIR`.

  Raises:
    ValueError: if the YAML top level is not a mapping, or if it contains keys
      that are not existing attributes of `args`.
  """
  pth = path.join(BASE_DIR, args.config + ".yaml")
  with open_file(pth, "r") as fin:
    # NOTE(review): FullLoader is kept for compatibility with existing config
    # files; only load trusted, repository-owned configs with it.
    configs = yaml.load(fin, Loader=yaml.FullLoader)
  # An empty YAML document parses to None; treat it as no overrides instead
  # of crashing on `.keys()`.
  if configs is None:
    configs = {}
  if not isinstance(configs, dict):
    raise ValueError(
        f"Config {pth} must contain a mapping, got {type(configs).__name__}.")
  # Only allow args to be updated if they already exist. Sorting (by str, in
  # case YAML produced non-string keys) makes the error message deterministic.
  invalid_args = sorted(set(configs.keys()) - set(dir(args)), key=str)
  if invalid_args:
    raise ValueError(f"Invalid args {invalid_args} in {pth}.")
  args.__dict__.update(configs)
248
249
def open_file(pth, mode="r"):
  """Open the file at `pth` with `mode` and return the file object.

  Args:
    pth: path of the file to open.
    mode: file mode passed to the builtin `open`.

  Returns:
    The open file object.

  Raises:
    NotImplementedError: if `INTERNAL` is set; only the local-filesystem
      backend is implemented in this file (previously this silently returned
      None).
  """
  if INTERNAL:
    raise NotImplementedError("open_file has no INTERNAL implementation.")
  return open(pth, mode=mode)
253
254
def file_exists(pth):
  """Return whether `pth` exists on the local filesystem.

  Raises:
    NotImplementedError: if `INTERNAL` is set; previously this silently
      returned None, which callers would read as "does not exist".
  """
  if INTERNAL:
    raise NotImplementedError("file_exists has no INTERNAL implementation.")
  return path.exists(pth)
258
259
def listdir(pth):
  """Return the list of entry names in directory `pth` (via `os.listdir`).

  Raises:
    NotImplementedError: if `INTERNAL` is set; previously this silently
      returned None.
  """
  if INTERNAL:
    raise NotImplementedError("listdir has no INTERNAL implementation.")
  return os.listdir(pth)
263
264
def isdir(pth):
  """Return whether `pth` is an existing directory on the local filesystem.

  Raises:
    NotImplementedError: if `INTERNAL` is set; previously this silently
      returned None, which callers would read as "not a directory".
  """
  if INTERNAL:
    raise NotImplementedError("isdir has no INTERNAL implementation.")
  return path.isdir(pth)
268
269
def makedirs(pth, exist_ok=False):
  """Recursively create directory `pth` (via `os.makedirs`).

  Args:
    pth: directory path to create, including any missing parents.
    exist_ok: if False (the default, matching the previous behavior), raise
      `FileExistsError` when `pth` already exists; if True, do nothing in that
      case.

  Raises:
    NotImplementedError: if `INTERNAL` is set; previously this was a silent
      no-op.
  """
  if INTERNAL:
    raise NotImplementedError("makedirs has no INTERNAL implementation.")
  os.makedirs(pth, exist_ok=exist_ok)
273
274
def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  The rays are flattened and processed `chunk` at a time. Each chunk is padded
  to a multiple of `jax.device_count()`, split across hosts, and sharded
  across this host's local devices before `render_fn` is called.

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered; every field's first
      two dimensions are taken as [height, width].
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, [height, width, C], rendered color image.
    disp: jnp.ndarray, [height, width, C], rendered disparity image.
    acc: jnp.ndarray, [height, width, C], rendered accumulated weights per
      pixel.
    features: jnp.ndarray, [height, width, C], rendered feature image.
    specular: jnp.ndarray, [height, width, C], rendered specular residual.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  results = []
  for i in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
    chunk_size = chunk_rays[0].shape[0]
    rays_remaining = chunk_size % jax.device_count()
    if rays_remaining != 0:
      # Repeat the last ray so the chunk splits evenly across all devices;
      # the extra outputs are dropped again by `unshard(..., padding)`.
      padding = jax.device_count() - rays_remaining
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    else:
      padding = 0
    # After padding the number of chunk_rays is always divisible by
    # host_count.
    rays_per_host = chunk_rays[0].shape[0] // jax.host_count()
    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    # NOTE(review): only the last element of render_fn's output is used, and
    # `x[0]` below presumably selects outputs already gathered across devices;
    # confirm against render_fn.
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    results.append([unshard(x[0], padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop
  rgb, disp, acc, _, features, specular = [
      jnp.concatenate(r, axis=0) for r in zip(*results)
  ]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp_min, disp_max = disp.min(), disp.max()
    # A constant disparity map (e.g. an empty scene) would otherwise divide
    # by zero and turn the whole image into NaNs; it becomes all zeros.
    disp_range = jnp.maximum(disp_max - disp_min, jnp.finfo(disp.dtype).eps)
    disp = (disp - disp_min) / disp_range

  def to_image(x):
    return x.reshape((height, width, -1))

  return tuple(to_image(x) for x in (rgb, disp, acc, features, specular))
328
329
def compute_psnr(mse):
  """Convert a mean squared error into PSNR, assuming a peak pixel value of 1.

  PSNR = -10 * log10(mse), evaluated as a natural log divided by ln(10).

  Args:
    mse: float, mean square error of pixels.

  Returns:
    psnr: float, the psnr value.
  """
  ln_ten = jnp.log(10.)
  return -10. * jnp.log(mse) / ln_ten
340
341
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., height, width, num_channels].
    img1: array. An image of size [..., height, width, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, will cause the per-pixel SSIM "map" to be
      returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
    The map is computed with "valid" convolutions, so its spatial size is
    [height - filter_size + 1, width - filter_size + 1].
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs over every batch axis and the channel axis, so the 2D
  # convolutions only ever see [height, width] slices, then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0**2) - mu00
  sigma11 = filt_fn(img1**2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  # Covariance magnitude is bounded by the product of standard deviations
  # (Cauchy-Schwarz); rounding in the blurred moments can violate this.
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  # JAX reductions take `axis` as a static (hashable) argument, so pass a
  # tuple rather than a list.
  ssim = jnp.mean(ssim_map, axis=tuple(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim
410
411
def save_img(img, pth):
  """Save an image to disk as an 8-bit PNG.

  Args:
    img: array-like, [height, width, channels] or [height, width]. Values are
      clipped to [0, 1] and scaled by 255 (truncated toward zero) before
      saving. A trailing singleton channel ([height, width, 1], e.g. the
      disp/acc images returned by render_image) is dropped so the image is
      written as single-channel grayscale.
    pth: string, path to save the image to.
  """
  img = np.clip(np.asarray(img), 0., 1.)
  if img.ndim == 3 and img.shape[-1] == 1:
    # PIL's fromarray does not map [H, W, 1] uint8 arrays to a mode; drop the
    # channel axis so it is written as an "L" (grayscale) image.
    img = img[..., 0]
  img_uint8 = (img * 255.).astype(np.uint8)
  # Convert before opening the file so a conversion error does not leave an
  # empty file behind.
  with open_file(pth, "wb") as imgout:
    Image.fromarray(img_uint8).save(imgout, "PNG")
423
424
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Log-linear learning-rate schedule with an optional eased warm-up.

  Without a delay, the rate equals lr_init at step 0 and lr_final at
  step >= max_steps, interpolating linearly in log space in between (i.e.
  exponential decay). When lr_delay_steps > 0, that rate is additionally
  multiplied by a factor that starts at lr_delay_mult and rises along a
  quarter sine wave to 1 once step >= lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for the given `step`.
  """
  progress = np.clip(step / max_steps, 0, 1)
  decayed_lr = np.exp(np.log(lr_init) * (1 - progress) +
                      np.log(lr_final) * progress)
  if not lr_delay_steps > 0:
    return 1. * decayed_lr
  # A kind of reverse cosine decay: ease from lr_delay_mult up to 1.
  warmup = np.sin(0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  delay_rate = lr_delay_mult + (1 - lr_delay_mult) * warmup
  return delay_rate * decayed_lr
460
461
def shard(xs):
  """Split data into shards for multiple devices along the first dimension.

  Every leaf of the pytree `xs` is reshaped from [n, ...] to
  [local_device_count, n // local_device_count, ...]. The reshape fails if n
  is not divisible by `jax.local_device_count()`.

  Uses `jax.tree_util.tree_map`, which is available in both old and new JAX
  releases (the `jax.tree_map` alias is deprecated and later removed).
  """
  num_local_devices = jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((num_local_devices, -1) + x.shape[1:]), xs)
466
467
def to_device(xs):
  """Place `xs` (an array or a pytree of arrays) on the default JAX device.

  Thin wrapper around `jax.device_put` with no explicit device argument.
  """
  placed = jax.device_put(xs)
  return placed
471
472
def unshard(x, padding=0):
  """Undo `shard` for one array and strip trailing padding rows.

  Merges the two leading axes, [d, m, ...] -> [d * m, ...], and then removes
  the last `padding` rows when `padding` is positive.
  """
  num_shards, per_shard = x.shape[0], x.shape[1]
  merged = x.reshape([num_shards * per_shard] + list(x.shape[2:]))
  return merged[:-padding] if padding > 0 else merged
479