google-research
3665 строк · 132.9 Кб
1{
2"cells": [
3{
4"cell_type": "markdown",
5"metadata": {
6"id": "JndnmDMp66FL"
7},
8"source": [
9"Copyright 2018 Google LLC.\n",
10"\n",
11"Licensed under the Apache License, Version 2.0 (the \"License\");"
12]
13},
14{
15"cell_type": "code",
16"execution_count": null,
17"metadata": {
18"id": "hMqWDc_m6rUC"
19},
20"outputs": [],
21"source": [
22"#@title Default title text\n",
23"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
24"# you may not use this file except in compliance with the License.\n",
25"# You may obtain a copy of the License at\n",
26"#\n",
27"# https://www.apache.org/licenses/LICENSE-2.0\n",
28"#\n",
29"# Unless required by applicable law or agreed to in writing, software\n",
30"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
31"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
32"# See the License for the specific language governing permissions and\n",
33"# limitations under the License."
34]
35},
36{
37"cell_type": "code",
38"execution_count": null,
39"metadata": {
40"id": "u3E9yyHoT8KO"
41},
42"outputs": [],
43"source": [
44"import jax.numpy as jnp\n",
45"import jax\n",
46"import numpy as np\n",
47"from PIL import Image\n",
48"import matplotlib.pyplot as plt\n",
49"import plotly.graph_objects as go\n",
50"import functools\n",
51"import jax.experimental.optimizers\n",
52"import time\n",
53"import flax\n",
54"import flax.linen as nn\n",
55"from typing import Sequence, Callable\n",
56"from IPython.display import clear_output\n",
57"import cv2\n",
58"import imageio\n",
59"import mediapy as media\n",
60"\n",
61"import os\n",
62"\n"
63]
64},
65{
66"cell_type": "code",
67"execution_count": null,
68"metadata": {
69"id": "qK31kBAtIJ3n"
70},
71"outputs": [],
72"source": [
73"def linear_to_srgb(linear):\n",
74" \"\"\"Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n",
75" eps = jnp.finfo(jnp.float32).eps\n",
76" srgb0 = 323 / 25 * linear\n",
77" srgb1 = (211 * jnp.maximum(eps, linear)**(5 / 12) - 11) / 200\n",
78" return jnp.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
79"\n",
80"def read_envmap(filename):\n",
81" with open(filename, 'rb') as f:\n",
82" return imageio.imread(f, 'exr')\n",
83"\n",
84"#envmap_linear = read_envmap(f'{DIRECTORY}/ninomaru_teien_4k.exr')\n",
85"#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_4k.exr')\n",
86"#envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_4k.exr')\n",
87"#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_50x99.exr')\n",
88"envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_50x99.exr')\n",
89"\n",
90"envmap_linear = np.fliplr(envmap_linear) # Blender flips this for some reason\n",
91"#envmap_linear = np.roll(envmap_linear, envmap_linear.shape[1]//2, axis=1)\n",
92"envmap_srgb = linear_to_srgb(envmap_linear)\n",
93"\n",
94"plt.imshow(envmap_srgb)"
95]
96},
97{
98"cell_type": "code",
99"execution_count": null,
100"metadata": {
101"id": "hvG8F8ADKMHb"
102},
103"outputs": [],
104"source": [
105"envmap_H = 50\n",
106"envmap_W = envmap_H * 2 - 1\n",
107"#envmap_H, envmap_W = envmap_linear.shape[:2]\n",
108"\n",
109"envmap_gt = envmap_linear\n",
110"#envmap_gt = linear_to_srgb(cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA))\n",
111"#envmap_gt = cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA)\n",
112"plt.imshow(envmap_gt)\n",
113"plt.axis('off')"
114]
115},
116{
117"cell_type": "code",
118"execution_count": null,
119"metadata": {
120"id": "O6RPcOJFowWR"
121},
122"outputs": [],
123"source": [
124"#with open(f'{DIRECTORY}/hotel_room_{envmap_H}x{envmap_W}.exr', 'wb') as f:\n",
125"# imageio.imsave(f, envmap_gt, 'exr')\n"
126]
127},
128{
129"cell_type": "code",
130"execution_count": null,
131"metadata": {
132"id": "hrCaejuLFcpI"
133},
134"outputs": [],
135"source": [
136"#envmap_H, envmap_W = envmap_linear.shape[:2]\n",
137"#omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1],\n",
138"# jnp.linspace(0.0, jnp.pi, envmap_H+1)[:-1])\n",
139"omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),\n",
140" jnp.linspace(0.0, jnp.pi, envmap_H+1)[:-1] + jnp.pi / (2.0 * envmap_H))\n",
141"\n",
142"dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])\n",
143"\n",
144"omega_theta = omega_theta.flatten()\n",
145"omega_phi = omega_phi.flatten()\n",
146"\n",
147"omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)\n",
148"omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)\n",
149"omega_z = jnp.cos(omega_theta)\n",
150"omega_xyz = jnp.stack([omega_x,\n",
151" omega_y,\n",
152" omega_z], axis=-1)\n",
153"\n"
154]
155},
156{
157"cell_type": "code",
158"execution_count": null,
159"metadata": {
160"id": "gJbvom3l620T"
161},
162"outputs": [],
163"source": [
164"def mse_to_psnr(mse):\n",
165" \"\"\"Compute PSNR given an MSE (we assume the maximum pixel value is 1).\"\"\"\n",
166" return -10. / jnp.log(10.) * jnp.log(mse)\n",
167" \n",
168"def get_rays(H, W, focal, c2w, rand_ort=False, key=None):\n",
169" \"\"\"\n",
170" c2w: 4x4 matrix\n",
171" output: two arrays of shape [H, W, 3]\n",
172" \"\"\"\n",
173" j, i = jnp.meshgrid(jnp.arange(W, dtype=jnp.float32),\n",
174" jnp.arange(H, dtype=jnp.float32))\n",
175" \n",
176" if rand_ort:\n",
177" k1, k2 = random.split(key)\n",
178"\n",
179" i += jax.random.uniform(k1, shape=(H, W)) - 0.5\n",
180" j += jax.random.uniform(k2, shape=(H, W)) - 0.5\n",
181" \n",
182" dirs = jnp.stack([ (j.flatten()-0.5*W)/focal,\n",
183" -(i.flatten()-0.5*H)/focal,\n",
184" -jnp.ones((H*W,), dtype=jnp.float32)], -1) # shape [HW, 3]\n",
185" \n",
186" rays_d = dirs @ c2w[:3, :3].T # shape [HW, 3]\n",
187" rays_o = c2w[:3,-1:].T.repeat(H*W, 0)\n",
188" return rays_o.reshape(H, W, 3), rays_d.reshape(H, W, 3)\n"
189]
190},
191{
192"cell_type": "code",
193"execution_count": null,
194"metadata": {
195"id": "bbHRuTQL7GZS"
196},
197"outputs": [],
198"source": [
199"def parse_bin(s):\n",
200" return int(s[1:], 2) / 2.**(len(s) - 1)\n",
201"\n",
202"\n",
203"def phi2(i):\n",
204" return parse_bin('.' + f'{i:b}'[::-1])\n",
205"\n",
206"def nice_uniform(N):\n",
207" u = []\n",
208" v = []\n",
209" for i in range(N):\n",
210" u.append(i / float(N))\n",
211" v.append(phi2(i))\n",
212" #pts.append((i/float(N), phi2(i)))\n",
213"\n",
214" return u, v\n",
215"\n",
216"def nice_uniform_spherical(N, hemisphere=True):\n",
217" \"\"\"implementation of http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html\"\"\"\n",
218" u, v = nice_uniform(N)\n",
219"\n",
220" theta = np.arccos(1.0 - np.array(u)) * (2.0 - int(hemisphere))\n",
221" phi = 2.0 * np.pi * np.array(v)\n",
222"\n",
223" return theta, phi\n",
224" \n",
225"hemisphere = True\n",
226"def get_all_camera_rays(N_cameras, camera_dist, H, W, focal):\n",
227" theta, phi = nice_uniform_spherical(N_cameras, hemisphere)\n",
228"\n",
229" camera_x_vec = np.sin(theta) * np.cos(phi)\n",
230" camera_y_vec = np.sin(theta) * np.sin(phi)\n",
231" camera_z_vec = np.cos(theta)\n",
232"\n",
233" rays_o_vec = []\n",
234" rays_d_vec = []\n",
235" cameras = []\n",
236" for i in range(N_cameras):\n",
237" camera = np.eye(4)\n",
238" camera[0, 3] = camera_x_vec[i] * camera_dist\n",
239" camera[1, 3] = camera_y_vec[i] * camera_dist\n",
240" camera[2, 3] = camera_z_vec[i] * camera_dist\n",
241"\n",
242" zdir = np.array([camera_x_vec[i], camera_y_vec[i], camera_z_vec[i]])\n",
243" zdir /= np.linalg.norm(zdir)\n",
244"\n",
245" ydir = np.array([0.0, 0.0, 1.0])\n",
246" ydir -= zdir * zdir.dot(ydir)\n",
247" ydir[0] += 1e-10 # make sure that cameras pointing straight down/up have a defined ydir\n",
248" ydir /= np.linalg.norm(ydir)\n",
249"\n",
250" xdir = np.cross(ydir, zdir)\n",
251"\n",
252"\n",
253" camera[:3, 0] = xdir\n",
254" camera[:3, 1] = ydir\n",
255" camera[:3, 2] = zdir\n",
256"\n",
257" cameras.append(camera)\n",
258"\n",
259" rays_o, rays_d = get_rays(H, W, focal, camera)\n",
260"\n",
261" rays_o_vec.append(rays_o)\n",
262" rays_d_vec.append(rays_d)\n",
263"\n",
264" rays_o_vec = jnp.stack(rays_o_vec, 0)\n",
265" rays_d_vec = jnp.stack(rays_d_vec, 0)\n",
266"\n",
267" return rays_o_vec, rays_d_vec"
268]
269},
270{
271"cell_type": "code",
272"execution_count": null,
273"metadata": {
274"id": "lsBMjzCg65dk"
275},
276"outputs": [],
277"source": [
278"def render_pixel(normal, lobe, envmap, mask):\n",
279" masked_envmap = envmap * mask[:, :, None]\n",
280" return (masked_envmap * lobe * jnp.sin(omega_theta).reshape(envmap_H, envmap_W, 1)).sum(0).sum(0) * dtheta_dphi / jnp.pi\n",
281"\n",
282"\n",
283"def render(envmap, mask, materials, normals, rays_d, alpha, shading='lambertian'):\n",
284" \"\"\"\n",
285" envmap: shape [h, w, 3]\n",
286" mask: shape [h, w]\n",
287" materials: dictionary with entries of shape [N, 3]\n",
288" normals: shape [N, 3]\n",
289" rays_d: shape [N, 3]\n",
290" alpha: shape [N, 1]\n",
291" \n",
292" output: rendered colors, shape [N, 3]\n",
293" \"\"\"\n",
294" \n",
295" assert shading in ['lambertian', 'phong', 'blinnphong']\n",
296"\n",
297" if shading in ['lambertian', 'phong', 'blinnphong']:\n",
298" # TODO: Feed in only pixels where alpha = 1\n",
299" lobes = jnp.maximum(0.0, (omega_xyz.reshape(1, envmap_H, envmap_W, 3) * normals[:, None, None, :]).sum(-1, keepdims=True)) * materials['albedo'][:, None, None, :] # [HW, envmap_H, envmap_W, 3]\n",
300"\n",
301" if shading == 'blinnphong':\n",
302" assert 'specular_albedo' in materials.keys()\n",
303" specular_albedo = materials['specular_albedo'][:, None, None, :]\n",
304" exponent = materials['specular_exponent'][:, None, None, :]\n",
305"\n",
306" d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
307" rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10)\n",
308"\n",
309" halfvectors = omega_xyz.reshape(1, envmap_H, envmap_W, 3) + rays_d_norm[:, None, None, :]\n",
310" halfvectors /= (jnp.linalg.norm(halfvectors, axis=-1, keepdims=True) + 1e-10) # [N, envmap_H, envmap_W, 3]\n",
311"\n",
312" lobes += jnp.maximum(0.0, (halfvectors * normals[:, None, None, :]).sum(-1, keepdims=True)) ** exponent * specular_albedo\n",
313"\n",
314" if shading == 'phong':\n",
315" assert 'specular_albedo' in materials.keys()\n",
316" specular_albedo = materials['specular_albedo'][:, None, None, :]\n",
317" exponent = materials['specular_exponent'][:, None, None, :]\n",
318"\n",
319" d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
320" rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10) # [N, 3]\n",
321"\n",
322" refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm # [N, 3]\n",
323" refdirs = refdirs[:, None, None, :]\n",
324"\n",
325" # No need to normalize because ||n|| = 1 and ||d|| = 1, so ||2(n.d)n - d|| = 1.\n",
326" print(\"Not normalizing here (because unnecessary, at least theoretically).\")\n",
327" #refdirs /= (jnp.linalg.norm(refdirs, axis=-1, keepdims=True) + 1e-10) # [N, HW, envmap_H, envmap_W, 3]\n",
328"\n",
329" lobes += jnp.maximum(0.0, (refdirs * omega_xyz.reshape(1, envmap_H, envmap_W, 3)).sum(-1, keepdims=True)) ** exponent * specular_albedo\n",
330" \n",
331" colors = jax.vmap(render_pixel, in_axes=(0, 0, None, None))(normals, lobes, envmap, mask) \n",
332" \n",
333" return colors * alpha"
334]
335},
336{
337"cell_type": "code",
338"execution_count": null,
339"metadata": {
340"id": "dq8Nmn1i9Dnf"
341},
342"outputs": [],
343"source": [
344"def load_img(pth: str, is_16bit: bool=False) -\u003e np.ndarray:\n",
345" \"\"\"Load an image and cast to float32.\"\"\"\n",
346" with utils.open_file(pth, 'rb') as f:\n",
347" if is_16bit:\n",
348" bytes_ = np.asarray(bytearray(f.read()), dtype=np.uint8) # Read bytes\n",
349" image = np.array(\n",
350" cv2.imdecode(bytes_, cv2.IMREAD_UNCHANGED), dtype=np.float32)\n",
351" else:\n",
352" image = np.array(Image.open(f), dtype=np.float32)\n",
353" return image\n",
354"\n",
355"#disp = load_img(os.path.join(data_dir, 'test', 'r_0_disp.tiff'), is_16bit=True)[:, :, :1] / 255.0\n",
356"#plt.imshow(1/disp-1)"
357]
358},
359{
360"cell_type": "code",
361"execution_count": null,
362"metadata": {
363"id": "CWfd1iC3iwmj"
364},
365"outputs": [],
366"source": []
367},
368{
369"cell_type": "code",
370"execution_count": null,
371"metadata": {
372"id": "vnLOEgyNOog9"
373},
374"outputs": [],
375"source": [
376"\"\"\"\n",
377"disp = load_img('{DIRECTORY}/r_4_disp_0029.tif', True) / 65535.0\n",
378"depth = 1.0 / disp[:, :, 0] - 1.0\n",
379"plt.imshow(depth)\n",
380"print(np.nanmin(depth), np.sqrt(4.0 ** 2 + 0.5 ** 2) - 1.0)\n",
381"\"\"\";"
382]
383},
384{
385"cell_type": "code",
386"execution_count": null,
387"metadata": {
388"id": "cdnzd5bFi_4D"
389},
390"outputs": [],
391"source": [
392"Config = configs.Config()\n",
393"Config.dataset_loader = 'Blender'\n",
394"Config.near = 6\n",
395"Config.far = 2\n",
396"Config.factor = 1\n",
397"Config.disp_tiff = True\n",
398"\n",
399"# Force loading disparities and normals\n",
400"Config.compute_disp_metrics = True\n",
401"Config.compute_normal_metrics = True\n",
402"Config.semantic_dir = None\n",
403"import queue\n",
404"import jax\n",
405"import json\n",
406"import os\n",
407"\n",
408"LOCAL_COLMAP_DIR = '/tmp/colmap/'\n",
409"\n",
410"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_srgb_128x128'\n",
411"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_srgb_128x128'; Config.disp_tiff = False\n",
412"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_linear_128x128'\n",
413"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_uniform_linear_128x128'\n",
414"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_farther_occlusions_uniform_linear_128x128'\n",
415"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_farfield_occluder_uniform_linear_128x128'\n",
416"#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_new_uniform_linear_128x128'\n",
417"data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'\n",
418"#materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.38}\n",
419"\n",
420"#data_loader = Blender('test', data_dir, Config)\n",
421"data_loader = Blender('occlusions', data_dir, Config)\n",
422"#data_loader._next_fn = data_loader._next_test"
423]
424},
425{
426"cell_type": "code",
427"execution_count": null,
428"metadata": {
429"id": "sWVplUIuxrdZ"
430},
431"outputs": [],
432"source": [
433"imgs_gt = []\n",
434"normals_gt = []\n",
435"disps_gt = []\n",
436"alpha_gt = []\n",
437"rays_o_ = []\n",
438"rays_d_ = []\n",
439"\n",
440"N_cameras = data_loader.size\n",
441"for i in range(N_cameras):\n",
442" batch = next(data_loader)\n",
443" imgs_gt.append(batch.rgb)\n",
444" normals_gt.append(batch.normals)\n",
445" alpha_gt.append(batch.alphas)\n",
446" disps_gt.append(batch.disps)\n",
447" rays_o_.append(batch.rays.origins)\n",
448" rays_d_.append(batch.rays.directions)\n"
449]
450},
451{
452"cell_type": "code",
453"execution_count": null,
454"metadata": {
455"id": "JtDBdudY1vgH"
456},
457"outputs": [],
458"source": [
459"imgs_gt = jnp.stack(imgs_gt, axis=0).reshape(N_cameras, -1, 3)\n",
460"normals_gt = jnp.stack(normals_gt, axis=0).reshape(N_cameras, -1, 3)\n",
461"alpha_gt = jnp.float32(jnp.stack(alpha_gt, axis=0).reshape(N_cameras, -1, 1) \u003e 0.99)\n",
462"disps_gt = jnp.stack(disps_gt, axis=0)[..., :1].reshape(N_cameras, -1, 1)\n",
463"rays_o_vec = jnp.stack(rays_o_, axis=0).reshape(N_cameras, -1, 3)\n",
464"rays_d_vec = jnp.stack(rays_d_, axis=0).reshape(N_cameras, -1, 3)\n",
465"\n",
466"t_surface_gt = 1.0 / disps_gt - 1.0\n",
467"materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.15}\n"
468]
469},
470{
471"cell_type": "code",
472"execution_count": null,
473"metadata": {
474"id": "R2DvKhauS1xl"
475},
476"outputs": [],
477"source": [
478"#pts = rays_o_vec + rays_d_vec * t_surface_gt\n",
479"\n",
480"#normals_gt = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)\n",
481"\n"
482]
483},
484{
485"cell_type": "code",
486"execution_count": null,
487"metadata": {
488"id": "-8jQSjT2RRBx"
489},
490"outputs": [],
491"source": [
492"H, W = data_loader.images[0].shape[:2]\n",
493"print(f\"There are {imgs_gt.shape[0]} images of size {H}x{W}\")"
494]
495},
496{
497"cell_type": "code",
498"execution_count": null,
499"metadata": {
500"id": "C_CDjVlQzlGw"
501},
502"outputs": [],
503"source": [
504"ind = 10\n",
505"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
506"plt.figure()\n",
507"plt.imshow(normals_gt[ind].reshape(H, W, 3) * 0.5 + 0.5)\n",
508"plt.figure()\n",
509"plt.imshow(np.where(alpha_gt[ind] \u003c 0.99, np.nan, t_surface_gt[0]).reshape(H, W, 1))\n",
510"plt.figure()\n",
511"plt.imshow(alpha_gt[ind].reshape(H, W, 1))\n"
512]
513},
514{
515"cell_type": "code",
516"execution_count": null,
517"metadata": {
518"id": "6Z-jwtNCms5O"
519},
520"outputs": [],
521"source": [
522"plt.imshow(imgs_gt[0].reshape(H, W, 3))\n",
523"plt.figure()\n",
524"plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)\n",
525"plt.figure()\n",
526"plt.imshow(t_surface_gt[0].reshape(H, W, 1))\n",
527"plt.figure()\n",
528"plt.imshow(alpha_gt[0].reshape(H, W, 1))\n"
529]
530},
531{
532"cell_type": "code",
533"execution_count": null,
534"metadata": {
535"id": "cR9X0LxU9QTY"
536},
537"outputs": [],
538"source": [
539"rays_d_vec.shape"
540]
541},
542{
543"cell_type": "code",
544"execution_count": null,
545"metadata": {
546"id": "hJaeLucx9LCU"
547},
548"outputs": [],
549"source": [
550"\n",
551"rays_d_r = rays_d_vec.reshape(-1, H, W, 3)\n",
552"\n",
553"occluder_relative_size = 0.1 #0.03 # Ratio of the unit sphere occluded by the occluder\n",
554"th = 1.0 - 2 * occluder_relative_size\n",
555"# Make masks\n",
556"mask_shape = 'circle'\n",
557"if mask_shape == 'circle' or mask_shape == 'two_circles' or mask_shape == 'three_circles':\n",
558" sdfs_gt = th - jnp.sum(-omega_xyz[None, :, :] * rays_d_r[:, H//2, W//2, :][:, None, :], axis=-1)\n",
559"if mask_shape == 'two_circles' or mask_shape == 'three_circles':\n",
560" camera_dirs = rays_d_r[:, H//2, W//2, :]\n",
561" up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs\n",
562" up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)\n",
563" rot_dirs = jnp.cross(up_dirs, camera_dirs, axis=-1) # direction of rotation is camera direction cross up direction\n",
564"\n",
565" rotation_angle = jnp.pi / 7\n",
566" dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))\n",
567" \n",
568" sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1)) # Occluder is aligned with the camera\n",
569"if mask_shape == 'three_circles':\n",
570" camera_dirs = rays_d_r[:, H//2, W//2, :]\n",
571" up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs\n",
572" up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)\n",
573" rot_dirs = up_dirs # direction of rotation is up direction\n",
574"\n",
575" rotation_angle = jnp.pi / 6\n",
576" dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))\n",
577" sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1)) # Occluder is aligned with the camera\n",
578"\n",
579"sdfs_gt = sdfs_gt.reshape(-1, envmap_H, envmap_W)\n",
580"masks_gt = jnp.float32(sdfs_gt \u003e 0.0)\n",
581"\n"
582]
583},
584{
585"cell_type": "code",
586"execution_count": null,
587"metadata": {
588"id": "0nRjSoAD9pBp"
589},
590"outputs": [],
591"source": [
592"shading = 'lambertian'\n",
593"exposure = 1.0\n",
594"\n",
595"# Render dataset\n",
596"num_devices = jax.local_device_count()\n",
597"imgs_gt = []\n",
598"\n",
599"render_gt_partial = functools.partial(render, shading=shading)\n",
600"for i in range(N_cameras // num_devices):\n",
601" i0 = i * num_devices\n",
602" i1 = i0 + num_devices\n",
603"\n",
604" imgs_gt_ = jax.pmap(render_gt_partial, in_axes=(None, 0, 0, 0, 0, 0))( # here pmap is faster\n",
605" envmap_gt * exposure,\n",
606" masks_gt[i0:i1],\n",
607" jax.tree_map(lambda x: x[i0:i1], materials_gt),\n",
608" normals_gt[i0:i1],\n",
609" rays_d_vec.reshape(-1, H*W, 3)[i0:i1],\n",
610" alpha_gt[i0:i1],\n",
611" )\n",
612" imgs_gt.append(imgs_gt_)\n",
613"\n",
614"imgs_gt = linear_to_srgb(jnp.concatenate(imgs_gt, axis=0))\n",
615"\n",
616"print(jnp.nanmax(imgs_gt))\n"
617]
618},
619{
620"cell_type": "code",
621"execution_count": null,
622"metadata": {
623"id": "NASFYPonhri0"
624},
625"outputs": [],
626"source": [
627"ind = 50\n",
628"plt.figure()\n",
629"plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
630]
631},
632{
633"cell_type": "code",
634"execution_count": null,
635"metadata": {
636"id": "brlbzchGscOQ"
637},
638"outputs": [],
639"source": [
640"def mfbrdf_map(viewdir, normal, albedo, roughness, eps=1e-15):\n",
641" half_vecs = omega_xyz + viewdir[None, :]\n",
642" half_vecs /= (jnp.linalg.norm(half_vecs, axis=-1, keepdims=True) + eps)\n",
643"\n",
644" n_dot_v = jnp.abs(jnp.sum(viewdir * normal)) + 1e-5\n",
645" n_dot_l = jnp.maximum(jnp.sum(omega_xyz * normal[None, :], axis=-1), 0.0)\n",
646" n_dot_h = jnp.maximum(jnp.sum(normal[None, :] * half_vecs, axis=-1), 0.0)\n",
647" l_dot_h = jnp.maximum(jnp.sum(omega_xyz * half_vecs, axis=-1), 0.0)\n",
648"\n",
649" print(n_dot_v.shape, n_dot_l.shape)\n",
650"\n",
651" F_0 = 0.04\n",
652" a = roughness**2\n",
653"\n",
654" D = a / (jnp.pi * ((a - 1.0) * n_dot_h ** 2 + 1.)**2)\n",
655" F = F_0 + (1. - F_0) * jnp.power(1. - l_dot_h, 5)\n",
656" #V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((n_dot_v * (1 - a) * n_dot_v + a))))\n",
657" V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((-1. * n_dot_v * a + n_dot_v) * n_dot_v + a)))\n",
658" brdf = D * F * V\n",
659"\n",
660" print(brdf.shape)\n",
661" brdf = brdf + (1. - F) * albedo / jnp.pi\n",
662" #brdf = jnp.reshape(brdf, [mapres[0], mapres[1], 3])\n",
663" return brdf\n",
664"\n",
665"brdf = mfbrdf_map(jnp.array([0.0, 0.0, 1.0]), jnp.array([0.0, 1.0, 1.0])/jnp.sqrt(2), 0.7, 0.7)\n",
666"plt.imshow(brdf.reshape(envmap_H, envmap_W))\n",
667"plt.colorbar()"
668]
669},
670{
671"cell_type": "code",
672"execution_count": null,
673"metadata": {
674"id": "9_J8f4Ow0SXG"
675},
676"outputs": [],
677"source": [
678"num_devices = jax.local_device_count()\n",
679"\n",
680"shading = 'lambertian'\n",
681"exposure = 1.0\n",
682"\n",
683"render_partial = functools.partial(render, shading=shading)\n",
684"\n",
685"gt_list = []\n",
686"#gt_list = ['materials']\n",
687"#gt_list = ['masks', 'materials']\n",
688"#gt_list = ['sdfs', 'materials']\n",
689"#gt_list = ['envmap', 'materials']\n",
690"#gt_list = ['envmap', 'masks']\n",
691"#gt_list = ['masks']\n",
692"#gt_list = ['masks', 'materials', 'envmap']\n",
693"\n",
694"class MLP(nn.Module):\n",
695" features: Sequence[int]\n",
696"\n",
697" @nn.compact\n",
698" def __call__(self, x, y=None):\n",
699" if y is not None:\n",
700" x = jnp.concatenate([x, y], axis=-1)\n",
701" for feat in self.features[:-1]:\n",
702" x = nn.relu(nn.Dense(feat)(x))\n",
703" x = nn.Dense(self.features[-1])(x)\n",
704" return x\n",
705"\n",
706"def get_pyramid_params(rng, num_pyramids, height, width, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias):\n",
707" pyramid_params = []\n",
708" for i in range(pyramid_num_scales):\n",
709" gsh, gsw = [sz // pyramid_resize_scale ** i for sz in [height, width]]\n",
710" key, rng = jax.random.split(rng)\n",
711" features = jax.random.normal(key, (num_pyramids, gsh, gsw)) * global_std + global_bias\n",
712" pyramid_params.append(features)\n",
713" return pyramid_params\n",
714"\n",
715"\n",
716"def pyramids_to_imgs(pyramid_params, pyramid_mult, img_inds):\n",
717" def pyramid_to_img(pyramid_params, pyramid_mult):\n",
718" acc_val = pyramid_params[-1]\n",
719" for i, curr_val in enumerate(pyramid_params[-2::-1], start=1):\n",
720" # upsample\n",
721" acc_val = jax.image.resize(acc_val * pyramid_mult, shape=curr_val.shape, method='linear')\n",
722" \n",
723" # accumulate\n",
724" acc_val += curr_val\n",
725" return acc_val\n",
726"\n",
727" # Select all pyramid parameters at given indices\n",
728" sub_pyramid_params = jax.tree_map(lambda t: t[img_inds], pyramid_params)\n",
729" return jax.vmap(pyramid_to_img, in_axes=(0, None))(sub_pyramid_params, pyramid_mult)\n",
730"\n",
731"\n",
732"def grad_norm_spherical(dirs, grad):\n",
733" \"\"\"\n",
734" Compute gradient norm restricted to the sphere.\n",
735" Assume dim 0 is elevation and 1 is azimuth.\n",
736"\n",
737" dirs: (N, 3) array of directions on the sphere\n",
738" grad: (N, 3) array of Cartesian gradients of points on dirs\n",
739" \"\"\"\n",
740"\n",
741" norm_spherical = grad - (dirs * grad).sum(axis=-1, keepdims=True) * dirs\n",
742" return jnp.sqrt((norm_spherical ** 2).sum(-1) + 1e-5)\n",
743"\n",
744"\n",
745"\n",
746"rays_o_r = rays_o_vec.reshape(-1, H*W, 3)\n",
747"rays_d_r = rays_d_vec.reshape(-1, H*W, 3)\n",
748"\n",
749"\n",
750"\n",
751"append_identity = True\n",
752"def posenc(x, L_encoding):\n",
753" if L_encoding \u003c= 0:\n",
754" return x\n",
755" else:\n",
756" scales = 2**jnp.arange(L_encoding)\n",
757" #shape = x.shape[:-1] + (-1,)\n",
758" #scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape)\n",
759"\n",
760" #four_feat = jnp.sin(\n",
761" # jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))\n",
762" shape = x.shape[:-1] + (-1,)\n",
763" scaled_x = x[..., None, :] * scales[:, None] # [..., L, D]\n",
764"\n",
765" four_feat = jnp.sin(\n",
766" jnp.stack([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1)) # [..., L, D, 2]\n",
767"\n",
768" four_feat = jnp.reshape(four_feat / scales[:, None, None], shape)\n",
769" print(\"Using Lipschitz posenc\")\n",
770" if append_identity:\n",
771" return jnp.concatenate([x] + [four_feat], axis=-1)\n",
772" else:\n",
773" return four_feat\n",
774"\n",
775"\n",
776"def params_to_sdf(params_sdf, img_inds):\n",
777" if sdf_representation == 'mlp':\n",
778" sdf, sdf_grad = jax.vmap(jax.value_and_grad(lambda x, y: sdf_mlp.apply(params_sdf, x, y)[0]))(\n",
779" posenc(omega_xyz[None, :, :], L_encoding_sdf).repeat(img_inds.shape[0], 0).reshape(-1, mlp_input_features),\n",
780" rays_o_vec[img_inds, 0, 0, :][:, None, :].repeat(envmap_H * envmap_W, 1).reshape(-1, 3)\n",
781" )\n",
782" elif sdf_representation == 'grid':\n",
783" sdf = params_sdf[img_inds[:, None], ...]\n",
784" else:\n",
785" sdf = pyramids_to_imgs(params_sdf, pyramid_mult, img_inds)\n",
786" return sdf\n",
787"\n",
788"\n",
789"def sdf_to_mask(x, width, curve='sigmoid'):\n",
790" if curve == 'sigmoid':\n",
791" return jax.nn.sigmoid(x * width)\n",
792" elif curve == 'laplace_cdf':\n",
793" return 0.5 + 0.5 * jnp.sign(x) * (1.0 - jnp.exp(-jnp.abs(x) * width))\n",
794" else:\n",
795" raise NotImplementedError('Only sigmoid and laplace_cdf for now.')\n",
796" \n",
797"def params_to_materials(params_materials, pts):\n",
798" mlp_res = material_mlp.apply(params_materials, posenc(pts, L_encoding_materials))\n",
799" materials = {}\n",
800" materials['albedo'] = jax.nn.sigmoid(mlp_res[..., 0:3])\n",
801" if shading in ['phong', 'blinnphong']:\n",
802" #materials['specular_albedo'] = 20.0 * jax.nn.sigmoid(mlp_res[..., 3:6])\n",
803" if is_dielectric:\n",
804" materials['specular_albedo'] = jax.nn.softplus(mlp_res[..., 3:4])\n",
805" materials['specular_exponent'] = jax.nn.softplus(mlp_res[..., 4:5])\n",
806" else:\n",
807" materials['specular_albedo'] = jax.nn.softplus(mlp_res[..., 3:6])\n",
808" materials['specular_exponent'] = jax.nn.softplus(mlp_res[..., 6:7])\n",
809"\n",
810" return materials\n",
811"\n",
812"R = jnp.array([[ 0.0, 1.0, 0.0],\n",
813" [-1.0, 0.0, 0.0],\n",
814" [ 0.0, 0.0, 1.0]])\n",
815"\n",
816"@jax.jit\n",
817"def get_loss(params_envmap, params_sdf, params_materials, gt, spatial_inds, img_inds, i, rng):\n",
818" rng, key = jax.random.split(rng)\n",
819"\n",
820" if 'envmap' in gt_list:\n",
821" envmap = envmap_gt * exposure\n",
822" else:\n",
823" envmap = params_to_envmap(params_envmap)\n",
824"\n",
825" normals = normals_gt[img_inds[:, None], spatial_inds, ...]\n",
826" t_surface = t_surface_gt[img_inds[:, None], spatial_inds, ...]\n",
827" alpha = alpha_gt[img_inds[:, None], spatial_inds, ...]\n",
828" rays_d = rays_d_r[img_inds[:, None], spatial_inds, ...]\n",
829" rays_o = rays_o_r[img_inds[:, None], spatial_inds, ...]\n",
830"\n",
831" if 'materials' in gt_list:\n",
832" materials = jax.tree_map(lambda x: x[img_inds[:, None], spatial_inds[img_inds], ...], materials_gt)\n",
833" else:\n",
834" pts = rays_o + rays_d * t_surface\n",
835" materials = params_to_materials(params_materials, pts)\n",
836"\n",
837" sdf = params_to_sdf(params_sdf, img_inds)\n",
838" sdf = sdf.reshape(img_inds.shape[0], envmap_H * envmap_W)\n",
839" #mask_width = 200.0 * (i / num_iters) + 10.0\n",
840" #masks = sdf_to_mask(sdf.reshape(-1, envmap_H, envmap_W), mask_width, 'sigmoid')\n",
841"\n",
842" sdf_curve = 'sigmoid'\n",
843"\n",
844" if 'masks' in gt_list or 'sdfs' in gt_list:\n",
845" if 'masks' in gt_list:\n",
846" masks = masks_gt[img_inds]\n",
847" else:\n",
848" masks = sdf_to_mask(sdfs_gt[img_inds], mask_width, sdf_curve)\n",
849" #masks = masks * 0.0 + 1.0\n",
850" #print(\"Setting masks to 1!!!!!!!!!!!!!!!!!\")\n",
851" eikonal_loss = 0.0\n",
852" length_loss = 0.0\n",
853" mask_area_loss = 0.0\n",
854" else:\n",
855" #mask_width = 10.0 #200.0 * (i / num_iters) + 10.0\n",
856" #mask_width = 20.0 * (i / num_iters) + 10.0\n",
857" #mask_width = 0.1\n",
858" #mask_width = 10.0 #150.0 * (i / num_iters) + 6.0\n",
859" #mask_width = 0.1 * jnp.exp(jnp.log(100) * i / num_iters)\n",
860" if straight_through_mode == 'hard':\n",
861" masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
862" masks = (masks + jax.lax.stop_gradient(jnp.float32(sdf \u003e 0.0) - masks))\n",
863" elif straight_through_mode == 'soft':\n",
864" print(\"I think this might actually be 'none' straight_through_mode instead of 'soft' because mask_width is 0.1\")\n",
865" soft_masks = sdf_to_mask(sdf, 0.1, sdf_curve)\n",
866" hard_masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
867" # Define masks with value of `hard_masks` but gradients of `soft_masks`\n",
868" masks = (soft_masks + jax.lax.stop_gradient(hard_masks - soft_masks))\n",
869" elif straight_through_mode == 'none':\n",
870" masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
871"\n",
872"\n",
873"\n",
874" if sdf_representation == 'mlp':\n",
875" # TODO: Only compute grad w.r.t. x, not w.r.t. other posenc components\n",
876" sdf_grad = sdf_grad[..., :3].reshape(img_inds.shape[0], envmap_H * envmap_W, 3)\n",
877" sdf_grad_norm = grad_norm_spherical(omega_xyz, sdf_grad.reshape(img_inds.shape[0], envmap_H * envmap_W, 3))\n",
878" eikonal_loss = (jnp.sin(omega_theta) * (sdf_grad_norm - 1) ** 2).sum() * dtheta_dphi\n",
879" else:\n",
880" eikonal_loss = 0.0\n",
881"\n",
882" # Compute entropy assuming mask = sigmoid(sdf)\n",
883" entropy_loss = jax.nn.softplus(-sdf) + sdf * (1.0 - jax.nn.sigmoid(sdf))\n",
884" mask_area_loss = 1.0 - masks # Try to make the occluder as small as possible\n",
885" #delta_sdf = jax.vmap(jax.grad(lambda x: sdf_to_mask(x, mask_width, sdf_curve)))(sdf.flatten()).reshape(img_inds.shape[0], envmap_H * envmap_W)\n",
886"\n",
887" #length_loss = (jnp.sin(omega_theta) * sdf_grad_norm * delta_sdf).sum() * dtheta_dphi\n",
888"\n",
889" res = jax.vmap(render_partial, in_axes=(None, 0, 0, 0, 0, 0))(\n",
890" envmap,\n",
891" masks.reshape(img_inds.shape[0], envmap_H, envmap_W),\n",
892" materials,\n",
893" normals,\n",
894" rays_d,\n",
895" alpha,\n",
896" )\n",
897"\n",
898" #diff = gt[img_inds[:, None], spatial_inds[img_inds], :] - res\n",
899" diff = gt[img_inds[:, None], spatial_inds, :] - linear_to_srgb(res)\n",
900" #loss_per_element = (diff ** 2).sum(-1)\n",
901" #loss_per_element = jnp.abs(diff).sum(-1)\n",
902" if False:\n",
903" p = 2 - 1.5 * i / num_iters\n",
904" data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)\n",
905" print(\"Using graduated nonconvexity in the loss\")\n",
906" else:\n",
907" data_loss = (jnp.abs(diff) ** 2).sum()\n",
908" print(\"Using L2 loss\")\n",
909"\n",
910" data_loss = data_loss / img_inds.shape[0] / spatial_inds.shape[-1]\n",
911"\n",
912" loss = data_loss\n",
913" loss += 0.1 * eikonal_loss / img_inds.shape[0]\n",
914" loss += 1e-5 * (mask_area_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi\n",
915" #loss += 1e-1 * ((jnp.abs(params_envmap) ** 2) * ell.flatten()[:, None, None]).mean()\n",
916" #loss += 0.01 * (entropy_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi\n",
917" #loss += 0.01 * length_loss / img_inds.shape[0]\n",
918"\n",
919" # Environment map TV loss\n",
920" loss += 1e-7 * (((envmap[:, 1:] - envmap[:, :-1]) ** 2).sum() + ((envmap[1:, :] - envmap[:-1, :]) ** 2).sum()) * dtheta_dphi / 4.0 / jnp.pi\n",
921"\n",
922" return loss, (data_loss, eikonal_loss)\n",
923"\n",
924"def safe_exp(x):\n",
"  \"\"\"Returns exp(min(x, 80)) elementwise; the cap keeps the exponential from overflowing.\"\"\"\n",
"  capped = jnp.minimum(x, 80.0)\n",
925"  return jnp.exp(capped)\n",
926"\n",
927"def tonemap_and_clip(x):\n",
"  \"\"\"Applies `linear_to_srgb` to `x` and clamps the result to [0, 1].\"\"\"\n",
"  srgb = linear_to_srgb(x)\n",
928"  return np.clip(srgb, 0.0, 1.0)\n",
929"\n",
930"def params_to_envmap(params_envmap):\n",
"  \"\"\"Maps the raw optimized parameters to a non-negative environment map.\n",
"\n",
"  In 'SH' mode `isht` is applied to each slice along the last axis and the\n",
"  result goes through a softplus; in every other mode the parameters are\n",
"  passed through `safe_exp`.\n",
"  \"\"\"\n",
"  # Alternatives that were tried: sigmoid or softplus directly on the parameters.\n",
"  if envmap_representation != 'SH':\n",
"    return safe_exp(params_envmap)\n",
"\n",
"  #params_envmap = jnp.where(ell.flatten()[:, None, None] \u003c 5, params_envmap, 0.0)\n",
"  per_channel_isht = jax.vmap(isht, in_axes=-1, out_axes=-1)\n",
"  return jax.nn.softplus(per_channel_isht(params_envmap))\n",
939"\n",
940"\n",
941"@jax.jit\n",
942"def update_params(i, rng, state_envmap, state_sdf, state_materials, gt, spatial_inds, img_inds):\n",
"  \"\"\"Runs one optimizer step on the envmap, SDF and material parameters.\n",
"\n",
"  Must be called under a mapped axis named 'batch' (e.g. through `jax.pmap`),\n",
"  because gradients and losses are averaged with `jax.lax.pmean` over it.\n",
"\n",
"  Args:\n",
"    i: step index, forwarded to `get_loss` and to the optimizer updates.\n",
"    rng: random key forwarded to `get_loss`.\n",
"    state_envmap, state_sdf, state_materials: optimizer states.\n",
"    gt, spatial_inds, img_inds: forwarded unchanged to `get_loss`.\n",
"\n",
"  Returns:\n",
"    (new_state_envmap, new_state_sdf, new_state_materials, loss, data_loss,\n",
"    eikonal_loss); the three losses are averaged over the 'batch' axis.\n",
"  \"\"\"\n",
"  current_params = (\n",
"      get_params_envmap(state_envmap),\n",
"      get_params_sdf(state_sdf),\n",
"      get_params_materials(state_materials),\n",
"  )\n",
"\n",
"  loss_and_grads = jax.value_and_grad(get_loss, argnums=(0, 1, 2), has_aux=True)\n",
"  (loss, (data_loss, eikonal_loss)), grads = loss_and_grads(\n",
"      *current_params, gt, spatial_inds, img_inds, i, rng)\n",
"\n",
"  # Average gradients and reported losses across the 'batch' axis.\n",
"  batch_mean = functools.partial(jax.lax.pmean, axis_name='batch')\n",
"  grad_envmap = batch_mean(grads[0])\n",
"  grad_sdf = batch_mean(grads[1])\n",
"  grad_materials = batch_mean(grads[2])\n",
"  eikonal_loss = batch_mean(eikonal_loss)\n",
"  data_loss = batch_mean(data_loss)\n",
"  loss = batch_mean(loss)\n",
"\n",
"  new_state_envmap = update_envmap(i, grad_envmap, state_envmap)\n",
"  new_state_sdf = update_sdf(i, grad_sdf, state_sdf)\n",
"  new_state_materials = update_materials(i, grad_materials, state_materials)\n",
957"  return new_state_envmap, new_state_sdf, new_state_materials, loss, data_loss, eikonal_loss\n",
958"\n",
959"\n",
960"#slow_optimization_mode = ('masks' not in gt_list and 'sdfs' not in gt_list) or 'materials' not in gt_list\n",
961"#if slow_optimization_mode:\n",
962"# print(\"Slow optimization\")\n",
963"#else:\n",
964"# print(\"Fast optimization\")\n",
965"\n",
966"num_iters = 150000 #50000 if slow_optimization_mode else 10000\n",
967"#num_iters = 50000\n",
968"straight_through_mode = 'soft'\n",
969"assert straight_through_mode in ['none', 'hard', 'soft']\n",
970"\n",
971"# TODO:\n",
972"# 1. It looks like using straight-through on the masks (with constant width 10) makes them be a little off,\n",
973"# but improves the envmap (making it a little noisier because of the bad masks). Why?\n",
974"# 2. \n",
975"\n",
"# Choose how the environment map is parameterized and build its Adam optimizer.\n",
976"envmap_representation = 'direct'  # one of 'SH', 'direct'\n",
977"if envmap_representation == 'SH':\n",
978"  #params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3)) - 0.5) * 0.01\n",
"  # Small random real and imaginary parts, scaled down by 1 / (1 + ell).\n",
979"  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3, 2)) - 0.5) / (1 + ell.flatten()[:, None, None, None]) * 0.01\n",
980"  params_envmap = params_envmap[..., 0] + 1j * params_envmap[..., 1]\n",
981"  init_lr_envmap = 0.0003 #if slow_optimization_mode else 0.001\n",
982"\n",
983"elif envmap_representation == 'direct':\n",
"  # One parameter per envmap texel and channel; params_to_envmap maps them through safe_exp.\n",
984"  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_W, 3)) - 0.5) * 0.1\n",
985"  #print(\"TODO: Get rid of this annoying -4.0\")\n",
986"\n",
987"  init_lr_envmap = 0.003 #if slow_optimization_mode else 0.03\n",
988"else:\n",
989"  raise ValueError(f\"Unknown envmap_representation {envmap_representation!r}; expected 'SH' or 'direct'.\")\n",
990"\n",
991"init_envmap, update_envmap, get_params_envmap = jax.experimental.optimizers.adam(init_lr_envmap)\n",
992"state_envmap = init_envmap(params_envmap)\n",
993"\n",
994"\n",
"# Choose how the per-camera SDF is parameterized and set its learning rate.\n",
995"sdf_representation = 'pyramid' # one of 'mlp', 'pyramid', 'grid'\n",
996"\n",
997"if sdf_representation == 'mlp':\n",
998"  L_encoding_sdf = 0 # 4\n",
999"  mlp_input_features = 3 + 6 * L_encoding_sdf\n",
1000"\n",
1001"  sdf_mlp = MLP([128]*4 + [1])\n",
1002"\n",
1003"  params_sdf = sdf_mlp.init(jax.random.PRNGKey(0),\n",
1004"                            np.zeros([1, mlp_input_features]),\n",
1005"                            np.zeros([1, 3]))\n",
1006"  init_lr_sdf = 0.0001\n",
1007"elif sdf_representation == 'pyramid':\n",
1008"  pyramid_num_scales = 5\n",
1009"  pyramid_resize_scale = 2\n",
1010"  pyramid_mult = 2.0\n",
1011"  global_std = 1.0 #0.1\n",
1012"  global_bias = 4.0\n",
1013"\n",
1014"  rng = jax.random.PRNGKey(0)\n",
1015"  params_sdf = get_pyramid_params(rng, N_cameras, envmap_H, envmap_W, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias)\n",
1016"  init_lr_sdf = 0.1 #* 100\n",
"  # NOTE(review): mask_width is assigned only in this branch, but the plotting\n",
"  # code reads it for every representation -- confirm it is also defined\n",
"  # elsewhere before switching to 'mlp' or 'grid'.\n",
1017"  mask_width = 0.1\n",
1018"\n",
1019"elif sdf_representation == 'grid':\n",
1020"  params_sdf = jax.random.normal(jax.random.PRNGKey(111), shape=(N_cameras, envmap_H, envmap_W))\n",
1021"  init_lr_sdf = 0.01\n",
1022"else:\n",
1023"  raise ValueError(f\"Unknown sdf_representation {sdf_representation!r}; expected 'mlp', 'pyramid' or 'grid'.\")\n",
1024"\n",
1025"#for _ in range(20):\n",
1026"# print(\"TODO: Use xmanager to optimize: learning rates, biases, global std for params_sdf in pyramid mode, number of iterations (longer!), etc.\")\n",
1027"# print(\"TODO: Find out what happens if for masks we use the ground truth SDFs passed through a 0.1 sigmoid, instead of the GT masks. This is the best case scenario when using such a soft sigmoid!\")\n",
1028"# print(\"TODO: Replace initialization as soft ~0.5ish masks and dark envmap with good init. of envmap and ~1 masks (no occluders). Currently things just go there anyway...\")\n",
1029"\n",
1030"init_sdf, update_sdf, get_params_sdf = jax.experimental.optimizers.adam(init_lr_sdf)\n",
1031"state_sdf = init_sdf(params_sdf)\n",
1032"\n",
1033"# Initialize material MLP\n",
1034"L_encoding_materials = 4\n",
1035"mlp_input_features = 3 + 6 * L_encoding_materials\n",
1036"\n",
1037"is_dielectric = True\n",
1038"num_components = 3 # 3 for diffuse\n",
1039"if shading in ['phong', 'blinnphong']:\n",
1040" if is_dielectric:\n",
1041" num_components += 2 # 1 for specular albedo, 1 for exponent\n",
1042" else:\n",
1043" num_components += 4 # 3 for specular albedo, 1 for exponent\n",
1044"material_mlp = MLP([128]*4 + [num_components])\n",
1045"\n",
1046"params_materials = material_mlp.init(jax.random.PRNGKey(0),\n",
1047" np.zeros([1, mlp_input_features]))\n",
1048"init_lr_materials = 0.003\n",
1049"init_materials, update_materials, get_params_materials = jax.experimental.optimizers.adam(init_lr_materials)\n",
1050"state_materials = init_materials(params_materials)\n",
1051"\n",
1052"\n",
1053"np_rng = np.random.default_rng(12345)\n",
1054"jax_rng = jax.random.PRNGKey(3948)\n",
1055"\n",
1056"spatial_batch_size = 64\n",
1057"\n",
1058"img_batch_size = N_cameras #1024\n",
1059"#img_batch_size = 64\n",
1060"losses = []\n",
1061"data_losses = []\n",
1062"eikonal_losses = []\n",
1063"envmap_psnrs = []\n",
1064"tonemapped_envmap_psnrs = []\n",
1065"envmaps = []\n",
1066"\n",
1067"t = 0.0\n",
1068"training_progress_bar = ProgressBar()\n",
1069"training_progress_bar.Publish()\n",
1070"\n",
1071"replicated_state_envmap = flax.jax_utils.replicate(state_envmap)\n",
1072"replicated_state_sdf = flax.jax_utils.replicate(state_sdf)\n",
1073"replicated_state_materials = flax.jax_utils.replicate(state_materials)\n",
1074"replicated_imgs_gt = flax.jax_utils.replicate(imgs_gt)\n",
1075"\n",
1076"\n",
1077"for iteration in range(num_iters):\n",
1078" #t0 = time.time()\n",
1079" # Generate B1 image indices\n",
1080" img_inds = np_rng.choice(imgs_gt.shape[0], size=img_batch_size, replace=False) #* 0\n",
1081" # Now generate B2 pixel indices for each image. The total batch size is B1 * B2.\n",
1082"\n",
1083" keys = jax.random.split(jax_rng, num=img_batch_size+1)\n",
1084" jax_rng, keys = keys[0], keys[1:]\n",
1085" spatial_inds = jax.vmap(jax.random.choice, in_axes=(0, None, None, None, 0))(keys, H*W, (spatial_batch_size,), False, alpha_gt[img_inds, :, 0])\n",
1086"\n",
1087" assert jnp.all(alpha_gt[img_inds[:, None], spatial_inds, :] \u003e 0.99)\n",
1088"\n",
1089" replicated_state_envmap, replicated_state_sdf, replicated_state_materials, loss, data_loss, eikonal_loss = jax.pmap(update_params, in_axes=(None, None, 0, 0, 0, 0, 0, 0), axis_name='batch')(\n",
1090" iteration,\n",
1091" jax_rng,\n",
1092" replicated_state_envmap,\n",
1093" replicated_state_sdf,\n",
1094" replicated_state_materials,\n",
1095" replicated_imgs_gt,\n",
1096" spatial_inds.reshape(num_devices, -1, spatial_batch_size),\n",
1097" img_inds.reshape(num_devices, -1)\n",
1098" )\n",
1099"\n",
1100" if iteration % 100 == 0 or iteration == num_iters - 1:\n",
1101" envmap = params_to_envmap(get_params_envmap(replicated_state_envmap)[0])\n",
1102" envmaps.append(envmap)\n",
1103" mse = (jnp.sin(omega_theta)[:, None] * (exposure * envmap_gt - envmap).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0\n",
1104" envmap_psnrs.append(mse_to_psnr(mse))\n",
1105" mse = (jnp.sin(omega_theta)[:, None] * (tonemap_and_clip(exposure * envmap_gt) - tonemap_and_clip(envmap)).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0\n",
1106" tonemapped_envmap_psnrs.append(mse_to_psnr(mse))\n",
1107"\n",
1108"\n",
1109" losses.append(loss[0])\n",
1110" data_losses.append(data_loss[0])\n",
1111" eikonal_losses.append(eikonal_loss[0])\n",
1112"\n",
1113" if iteration in [500] or iteration % 1000 == 0 and iteration \u003e 0:\n",
1114" clear_output(wait=True)\n",
1115" training_progress_bar.Publish()\n",
1116" training_progress_bar.SetProgress(100.0 * (iteration + 1) / num_iters)\n",
1117"\n",
1118" plt.figure()\n",
1119" plt.semilogy(data_losses)\n",
1120" plt.semilogy(eikonal_losses)\n",
1121"\n",
1122" # Plot envmaps\n",
1123" plt.figure(figsize=[16, 8])\n",
1124" plt.subplot(221)\n",
1125" envmap = params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))\n",
1126" plt.imshow(envmap)\n",
1127" plt.axis('off')\n",
1128" plt.subplot(222)\n",
1129" plt.imshow(exposure * envmap_gt)\n",
1130" plt.axis('off')\n",
1131" plt.subplot(223)\n",
1132" plt.imshow(linear_to_srgb(params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))\n",
1133" plt.axis('off')\n",
1134" plt.subplot(224)\n",
1135" plt.imshow(linear_to_srgb(exposure * envmap_gt))\n",
1136" plt.axis('off')\n",
1137"\n",
1138" # Plot PSNRs\n",
1139" plt.figure(figsize=[8, 8])\n",
1140" plt.subplot(121)\n",
1141" plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), envmap_psnrs)\n",
1142" plt.subplot(122)\n",
1143" plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), tonemapped_envmap_psnrs)\n",
1144"\n",
1145" # Plot masks\n",
1146" num_rows = 6\n",
1147" cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]\n",
1148" sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))\n",
1149" sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))\n",
1150"\n",
1151" plt.figure(figsize=[15, 12])\n",
1152" for row, i in enumerate(cameras_to_plot):\n",
1153" sdf = sdfs[row].reshape(envmap_H, envmap_W)\n",
1154" for col, img_to_plot in enumerate([sdf, sdf_to_mask(sdf, mask_width), sdf \u003e 0]):\n",
1155" plt.subplot(num_rows, 3, row * 3 + col + 1)\n",
1156" if img_to_plot.shape[-1] != 3:\n",
1157" if col == 0:\n",
1158" plt.imshow(img_to_plot, cmap='gray')\n",
1159" else:\n",
1160" plt.imshow(img_to_plot, cmap='gray', vmin=0.0, vmax=1.0)\n",
1161" else:\n",
1162" plt.imshow(img_to_plot)\n",
1163" plt.axis('off')\n",
1164" plt.title(f'{i}')\n",
1165"\n",
1166"\n",
1167" # Plot materials\n",
1168" ind = 10\n",
1169" pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]\n",
1170" materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)\n",
1171" for k in materials.keys():\n",
1172" plt.figure()\n",
1173" plt.imshow(materials[k].reshape(H, W, -1))\n",
1174" plt.axis('off')\n",
1175" \n",
1176" # Plot rendered image and GT image\n",
1177" sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1178" mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1179" if 'envmap' in gt_list:\n",
1180" envmap = exposure * envmap_gt\n",
1181" if 'mask' in gt_list:\n",
1182" mask = masks_gt[ind]\n",
1183" if 'materials' in gt_list:\n",
1184" materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1185" res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1186" plt.figure()\n",
1187" plt.subplot(121)\n",
1188" plt.imshow(res)\n",
1189" plt.axis('off')\n",
1190" plt.subplot(122)\n",
1191" plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1192" plt.axis('off')\n",
1193" plt.show()\n",
1194"\n"
1195]
1196},
1197{
1198"cell_type": "code",
1199"execution_count": null,
1200"metadata": {
1201"id": "WyAaBUBblCpR"
1202},
1203"outputs": [],
1204"source": [
1205"plt.imshow(jnp.log(1e-5 + params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))\n",
1206"plt.figure()\n",
1207"plt.imshow(jnp.log(1e-5 + exposure * envmap_gt))\n"
1208]
1209},
1210{
1211"cell_type": "code",
1212"execution_count": null,
1213"metadata": {
1214"id": "v81Xb-E5T_oR"
1215},
1216"outputs": [],
1217"source": [
1218"ind = 3\n",
1219"pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]\n",
1220"materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)\n",
1221"\n",
1222"sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1223"mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1224"\n",
1225"res = linear_to_srgb(render_partial(envmap_gt, mask[0], {'albedo': materials['albedo'] * 0.0 + 0.15}, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1226"plt.imshow(res)\n",
1227"plt.figure()\n",
1228"plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
1229]
1230},
1231{
1232"cell_type": "code",
1233"execution_count": null,
1234"metadata": {
1235"id": "LDMPMheWVG8G"
1236},
1237"outputs": [],
1238"source": [
1239"ind = 0\n",
1240"\n",
1241"sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1242"mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1243"\n",
1244"#envmap_top = np.zeros_like(envmap_gt)\n",
1245"#envmap_top[0, :, :] = 1000\n",
1246"#envmap_top[25, 49, :] = 10000\n",
1247"res = render_partial(envmap_gt, mask[0]*0.0+1.0,\n",
1248" {'albedo': materials['albedo'] * 0.0 + 0.15},\n",
1249" normals_gt[ind], -rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1250"plt.imshow(linear_to_srgb(res), interpolation='nearest')\n",
1251"plt.figure()\n",
1252"plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
1253]
1254},
1255{
1256"cell_type": "code",
1257"execution_count": null,
1258"metadata": {
1259"id": "RTlmYw1CtsWn"
1260},
1261"outputs": [],
1262"source": [
1263"omega_phi.reshape(envmap_H, envmap_W)[0, 49]"
1264]
1265},
1266{
1267"cell_type": "code",
1268"execution_count": null,
1269"metadata": {
1270"id": "8dWB1SHRroGu"
1271},
1272"outputs": [],
1273"source": []
1274},
1275{
1276"cell_type": "code",
1277"execution_count": null,
1278"metadata": {
1279"id": "-CvSkOxhroXv"
1280},
1281"outputs": [],
1282"source": [
1283"plt.imshow(omega_x.reshape(envmap_H, envmap_W))"
1284]
1285},
1286{
1287"cell_type": "code",
1288"execution_count": null,
1289"metadata": {
1290"id": "Nzezbd06rodC"
1291},
1292"outputs": [],
1293"source": []
1294},
1295{
1296"cell_type": "code",
1297"execution_count": null,
1298"metadata": {
1299"id": "CG4EL5ySpm2U"
1300},
1301"outputs": [],
1302"source": [
1303"plt.plot(res[64])\n",
1304"print(jnp.nanmin(res) / jnp.nanmax(res))\n",
1305"print(jnp.nanmax(res) * 0.27823895, res[18, 64, 0])"
1306]
1307},
1308{
1309"cell_type": "code",
1310"execution_count": null,
1311"metadata": {
1312"id": "H08mIu3Tj9B6"
1313},
1314"outputs": [],
1315"source": [
1316"print(alpha_gt[0].reshape(H, W)[18, :].sum())\n",
1317"print(alpha_gt[0].reshape(H, W)[18, 64])"
1318]
1319},
1320{
1321"cell_type": "code",
1322"execution_count": null,
1323"metadata": {
1324"id": "xVE4XVbBj5fz"
1325},
1326"outputs": [],
1327"source": [
1328"normals_gt[0].reshape(H, W, 3)[18, 64, :]"
1329]
1330},
1331{
1332"cell_type": "code",
1333"execution_count": null,
1334"metadata": {
1335"id": "Kijvl705lOPr"
1336},
1337"outputs": [],
1338"source": [
1339"def jon(x, eps=1e-7):\n",
"  \"\"\"Computes x / sqrt(max(x**2, eps)), forced to exactly 0 wherever x**2 \u003c eps.\"\"\"\n",
"  squared = x ** 2\n",
"  is_tiny = squared \u003c eps\n",
"  unit = x / jnp.sqrt(jnp.maximum(squared, eps))\n",
1342"  return jnp.where(is_tiny, jnp.zeros_like(unit), unit)\n",
1343"\n",
1344"def dor(x, eps=1e-7):\n",
"  \"\"\"Computes x / sqrt(max(x**2, eps)); unlike `jon`, small inputs are not zeroed.\"\"\"\n",
"  floored_sq = jnp.maximum(x**2, eps)\n",
1345"  return x / jnp.sqrt(floored_sq)\n",
1346"\n",
1347"\n",
1348"x = jnp.linspace(-0.001, 0.001, 10000)\n",
1349"plt.plot(x, jon(x), x, dor(x))"
1350]
1351},
1352{
1353"cell_type": "code",
1354"execution_count": null,
1355"metadata": {
1356"id": "OMCb95CSPk7R"
1357},
1358"outputs": [],
1359"source": [
1360"ind = 0\n",
1361"\n",
1362"print(rays_d_vec[ind].reshape(H, W, 3)[0, W//2, :]) # Top is x\n",
1363"print(rays_d_vec[ind].reshape(H, W, 3)[H//2, 1, :]) # Left is y\n",
1364"\n",
1365"n = (normals_gt[ind].reshape(H, W, 3) @ R.T) * 0.5 + 0.5\n",
1366"plt.imshow(n)"
1367]
1368},
1369{
1370"cell_type": "code",
1371"execution_count": null,
1372"metadata": {
1373"id": "8jAr6Lcp9dl8"
1374},
1375"outputs": [],
1376"source": []
1377},
1378{
1379"cell_type": "code",
1380"execution_count": null,
1381"metadata": {
1382"id": "KNyuQ0oryIWO"
1383},
1384"outputs": [],
1385"source": [
1386"plt.imshow(envmap_gt)\n",
1387"plt.figure()\n",
1388"sdf = params_to_sdf(sdf_params, jnp.array([0]))\n",
1389"mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1390"\n",
1391"plt.imshow(mask[0])"
1392]
1393},
1394{
1395"cell_type": "code",
1396"execution_count": null,
1397"metadata": {
1398"id": "tvbSIOQk9OX7"
1399},
1400"outputs": [],
1401"source": []
1402},
1403{
1404"cell_type": "code",
1405"execution_count": null,
1406"metadata": {
1407"id": "6Tx54EIFxFPZ"
1408},
1409"outputs": [],
1410"source": [
1411"plt.imshow(masks_gt[-1])"
1412]
1413},
1414{
1415"cell_type": "code",
1416"execution_count": null,
1417"metadata": {
1418"id": "gwCzHL7CykR3"
1419},
1420"outputs": [],
1421"source": [
1422"#plt.figure(figsize=[12, 12])\n",
1423"#plt.plot(masks_gt[-1][34, :], '.')\n",
1424"plt.plot((1-masks_gt[-1]).sum(1), '.')"
1425]
1426},
1427{
1428"cell_type": "code",
1429"execution_count": null,
1430"metadata": {
1431"id": "hyD_ca65q9aK"
1432},
1433"outputs": [],
1434"source": [
1435"# Plot rendered image and GT image\n",
1436"sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1437"mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1438"if 'envmap' in gt_list:\n",
1439" envmap = exposure * envmap_gt\n",
1440"if 'mask' in gt_list:\n",
1441" mask = masks_gt[ind]\n",
1442"if 'materials' in gt_list:\n",
1443" materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1444"res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1445"plt.figure()\n",
1446"plt.subplot(121)\n",
1447"plt.imshow(res)\n",
1448"plt.axis('off')\n",
1449"plt.subplot(122)\n",
1450"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1451"plt.axis('off')\n",
1452"plt.show()\n",
1453"\n",
1454"diff = imgs_gt[ind].reshape(H, W, 3) - res\n",
1455"#loss_per_element = (diff ** 2).sum(-1)\n",
1456"#loss_per_element = jnp.abs(diff).sum(-1)\n",
1457"if False:\n",
1458" p = 2 - 1.5 * i / num_iters\n",
1459" data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)\n",
1460" print(\"Using graduated nonconvexity in the loss\")\n",
1461"else:\n",
1462" data_loss = (jnp.abs(diff) ** 2).sum()\n",
1463" print(\"Using L2 loss\")\n",
1464"\n",
1465"data_loss = data_loss / 1 / 128 / 128\n",
1466"print(data_loss)"
1467]
1468},
1469{
1470"cell_type": "code",
1471"execution_count": null,
1472"metadata": {
1473"id": "iZii_KyJrN17"
1474},
1475"outputs": [],
1476"source": [
1477"diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1478"diff.min(), diff.max()"
1479]
1480},
1481{
1482"cell_type": "code",
1483"execution_count": null,
1484"metadata": {
1485"id": "5G4_YPmorCnm"
1486},
1487"outputs": [],
1488"source": [
1489"plt.imshow(mask[0])"
1490]
1491},
1492{
1493"cell_type": "code",
1494"execution_count": null,
1495"metadata": {
1496"id": "URJTcUOo9rdb"
1497},
1498"outputs": [],
1499"source": []
1500},
1501{
1502"cell_type": "code",
1503"execution_count": null,
1504"metadata": {
1505"id": "bhkb4AWsGzmn"
1506},
1507"outputs": [],
1508"source": [
1509"media.show_video(envmaps, height=300)"
1510]
1511},
1512{
1513"cell_type": "code",
1514"execution_count": null,
1515"metadata": {
1516"id": "jsUCKGrFbyt7"
1517},
1518"outputs": [],
1519"source": [
1520"num_rows = 6\n",
1521"cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]\n",
1522"sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))\n",
1523"sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))\n",
1524"plt.imshow(jnp.float32(sdfs[2] \u003e 0) - jnp.float32(sdfs[1] \u003e 0))"
1525]
1526},
1527{
1528"cell_type": "code",
1529"execution_count": null,
1530"metadata": {
1531"id": "p2lbNspacFqv"
1532},
1533"outputs": [],
1534"source": [
1535"plt.imshow(sdfs[3] \u003e 0)"
1536]
1537},
1538{
1539"cell_type": "code",
1540"execution_count": null,
1541"metadata": {
1542"id": "_cNuUtm2Tz3L"
1543},
1544"outputs": [],
1545"source": [
1546"jnp.where(alpha_gt == 1.0, imgs_gt, 0.0).max()"
1547]
1548},
1549{
1550"cell_type": "code",
1551"execution_count": null,
1552"metadata": {
1553"id": "YYEJmEDP4KKk"
1554},
1555"outputs": [],
1556"source": [
1557"plt.imshow(np.float32(imgs_gt[5].reshape(H, W, 3) == 1))"
1558]
1559},
1560{
1561"cell_type": "code",
1562"execution_count": null,
1563"metadata": {
1564"id": "WKPz2rHYKZHV"
1565},
1566"outputs": [],
1567"source": [
1568"media.show_video(envmaps, height=200)"
1569]
1570},
1571{
1572"cell_type": "code",
1573"execution_count": null,
1574"metadata": {
1575"id": "gc21N4x__zqC"
1576},
1577"outputs": [],
1578"source": [
1579"ind = 11\n",
1580"sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1581"mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1582"envmap = envmap_gt\n",
1583"res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1584"plt.figure()\n",
1585"plt.subplot(121)\n",
1586"plt.imshow(res)\n",
1587"plt.axis('off')\n",
1588"plt.subplot(122)\n",
1589"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1590"plt.axis('off')\n",
1591"plt.show()\n"
1592]
1593},
1594{
1595"cell_type": "code",
1596"execution_count": null,
1597"metadata": {
1598"id": "XLeHIuktbhJN"
1599},
1600"outputs": [],
1601"source": [
1602"omega_xyz.shape"
1603]
1604},
1605{
1606"cell_type": "code",
1607"execution_count": null,
1608"metadata": {
1609"id": "BESEhvA8iJ6h"
1610},
1611"outputs": [],
1612"source": [
1613"ind\n",
1614"\n",
1615"materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1616"envmap = envmap_gt\n",
1617"\n",
1618"mask = jnp.sum(-omega_xyz * rays_d_vec[ind, H//2*W+W//2, :][None, :], axis=-1) \u003c 0.76649692 # Occluder is aligned with the camera\n",
1619"\n",
1620"res = linear_to_srgb(render_partial(envmap, mask.reshape(envmap_H, envmap_W), materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1621"#diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1622"plt.figure()\n",
1623"plt.imshow(res)\n",
1624"\"\"\"\n",
1625"plt.figure()\n",
1626"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1627"plt.figure()\n",
1628"err = np.abs(diff).sum(-1)\n",
1629"plt.imshow(err, cmap='gray')\n",
1630"plt.colorbar()\n",
1631"\"\"\";"
1632]
1633},
1634{
1635"cell_type": "code",
1636"execution_count": null,
1637"metadata": {
1638"id": "l2d7rmUmcRGH"
1639},
1640"outputs": [],
1641"source": [
1642"plt.imshow(normals_gt[ind].reshape(H, W, 3))"
1643]
1644},
1645{
1646"cell_type": "code",
1647"execution_count": null,
1648"metadata": {
1649"id": "PTMdZXyboxE-"
1650},
1651"outputs": [],
1652"source": [
1653"envmap = envmap_gt\n",
1654"\n",
1655"rows = []\n",
1656"for i in range(H):\n",
1657" inds = np.arange(W) + i * W\n",
1658" materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1659" res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))\n",
1660" rows.append(res)\n",
1661"\n",
1662"res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1663"diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1664"plt.figure()\n",
1665"plt.imshow(res)\n",
1666"plt.figure()\n",
1667"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1668"plt.figure()\n",
1669"err = np.abs(diff).sum(-1)\n",
1670"plt.imshow(err, cmap='gray')\n",
1671"plt.colorbar()\n"
1672]
1673},
1674{
1675"cell_type": "code",
1676"execution_count": null,
1677"metadata": {
1678"id": "98uykDQhWc1b"
1679},
1680"outputs": [],
1681"source": [
1682"envmap = envmap_linear\n",
1683"\n",
1684"rows = []\n",
1685"for i in range(H):\n",
1686" cols = []\n",
1687" for j in range(W):\n",
1688" inds = np.arange(1) + i * W + j\n",
1689" materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1690" res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))\n",
1691" cols.append(res)\n",
1692" row = jnp.concatenate(cols, axis=0)\n",
1693" rows.append(row)\n",
1694"\n",
1695"res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1696"diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1697"plt.figure()\n",
1698"plt.imshow(res)\n",
1699"plt.figure()\n",
1700"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1701"plt.figure()\n",
1702"err = np.abs(diff).sum(-1)\n",
1703"plt.imshow(err, cmap='gray')\n",
1704"plt.colorbar()\n"
1705]
1706},
1707{
1708"cell_type": "code",
1709"execution_count": null,
1710"metadata": {
1711"id": "b6Mo5B8ba3Il"
1712},
1713"outputs": [],
1714"source": [
1715"res = jnp.concatenate(rows, axis=0).reshape(-1, W, 3)\n",
1716"#diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1717"plt.figure()\n",
1718"plt.imshow(res)\n",
1719"plt.figure()\n",
1720"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1721"plt.figure()\n",
1722"err = np.abs(diff).sum(-1)\n",
1723"plt.imshow(err, cmap='gray')\n",
1724"plt.colorbar()\n"
1725]
1726},
1727{
1728"cell_type": "code",
1729"execution_count": null,
1730"metadata": {
1731"id": "vKdgW6WSjFmF"
1732},
1733"outputs": [],
1734"source": [
1735"materials['albedo'].shape"
1736]
1737},
1738{
1739"cell_type": "code",
1740"execution_count": null,
1741"metadata": {
1742"id": "lf7lQmu2i-o-"
1743},
1744"outputs": [],
1745"source": [
1746"plt.plot(err[:, 64])"
1747]
1748},
1749{
1750"cell_type": "code",
1751"execution_count": null,
1752"metadata": {
1753"id": "ixZYlvq8jPT1"
1754},
1755"outputs": [],
1756"source": [
1757"def interp2d(grids, inds):\n",
"  \"\"\"Linearly interpolates every channel of a grid at fractional indices.\n",
"\n",
"  Args:\n",
"    grids: array of shape [H, W, d].\n",
"    inds: coordinates of shape [2, ...]; entry 0 is elevation and entry 1 is\n",
"      azimuth. Out-of-range coordinates wrap around.\n",
"\n",
"  Returns:\n",
"    Array of shape [..., d] holding the interpolated values per channel.\n",
"  \"\"\"\n",
"  num_channels = grids.shape[-1]\n",
"  sampled = [\n",
"      jax.scipy.ndimage.map_coordinates(grids[:, :, c], inds, order=1, mode='wrap')\n",
"      for c in range(num_channels)\n",
"  ]\n",
1768"  return jnp.stack(sampled, axis=-1)\n",
1769"\n",
1770"\n",
1771"def render_pixel_mirror(refdir, envmap, mask):\n",
"  \"\"\"Looks up the masked envmap texel that a single direction `refdir` points at.\n",
"\n",
"  `refdir` is unpacked into (x, y, z); the polar angle is measured from +z and\n",
"  the azimuth from +x. Both angles are quantized to integer texel indices.\n",
"  \"\"\"\n",
"  dir_x, dir_y, dir_z = refdir\n",
"  masked_envmap = envmap * mask[:, :, None]\n",
"\n",
"  theta = jnp.arctan2(jnp.sqrt(dir_x ** 2 + dir_y ** 2), dir_z)\n",
"  phi = jnp.arctan2(dir_y, dir_x)\n",
"\n",
"  # Quantize the angles to get the row / column index.\n",
"  row = jnp.floor(envmap_H * theta / jnp.pi).astype(jnp.int32)\n",
"  col = jnp.round(envmap_W * phi / 2.0 / jnp.pi).astype(jnp.int32)\n",
"  # Interpolated alternative:\n",
"  #return interp2d(envmap * mask[:, :, None], [theta*(envmap_H-1)/jnp.pi, phi*(envmap_W-1)/2/jnp.pi])\n",
1778"  return masked_envmap[row, col]\n",
1780"\n",
1781"def render_mirror(envmap, mask, normals, rays_d, alpha, oxyz, rad, shape='sphere'):\n",
1782"  \"\"\"\n",
"  Renders a perfect mirror: each ray direction is reflected about its normal\n",
"  and the masked envmap is looked up along the reflected direction.\n",
"\n",
1783"  envmap: shape [h, w, 3]\n",
1784"  mask: shape [h, w]\n",
1786"  normals: shape [N, 3]\n",
1787"  rays_d: shape [N, 3] (normalized inside this function)\n",
1788"  alpha: shape [N, 1]\n",
1789"  oxyz: shape [1, 3] (shape center); currently unused\n",
1790"  rad: float (shape radius); currently unused\n",
"  shape: currently unused\n",
1791"\n",
1792"  output: rendered colors, shape [N, 3]\n",
1793"  \"\"\"\n",
1794"  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
"  # The small constant guards against division by zero for zero-length rays.\n",
1795"  rays_d_norm = rays_d / jnp.sqrt(d_norm_sq + 1e-10) # [N, 3]\n",
1796"\n",
"  # Reflection of d about n: r = 2 (n . d) n - d.\n",
1797"  refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm # [N, 3]\n",
1799"  colors = jax.vmap(render_pixel_mirror, in_axes=(0, None, None))(refdirs, envmap, mask)\n",
1800"\n",
1801"  return colors * alpha\n",
1802"\n",
1803"img_ind = 0\n",
1804"d = -rays_d_vec[img_ind] #* jnp.array([1.0, -1.0, 1.0])\n",
1805"R = jnp.array([[ 0.0, 1.0, 0.0],\n",
1806" [-1.0, 0.0, 0.0],\n",
1807" [ 0.0, 0.0, 1.0]])\n",
1808"n = normals_gt[img_ind] @ R\n",
1809"\n",
1810"res = render_mirror(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0,\n",
1811" n, d, alpha_gt[img_ind], jnp.zeros((3,)), 1.0)\n",
1812"res_srgb = linear_to_srgb(res.reshape(H, W, 3))\n",
1813"plt.figure(figsize=[12, 12])\n",
1814"plt.imshow(res_srgb, interpolation='nearest')\n",
1815"plt.axis('off')"
1816]
1817},
1818{
1819"cell_type": "code",
1820"execution_count": null,
1821"metadata": {
1822"id": "6LjFQh_tcFiA"
1823},
1824"outputs": [],
1825"source": [
1826"plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)"
1827]
1828},
1829{
1830"cell_type": "code",
1831"execution_count": null,
1832"metadata": {
1833"id": "9_bzZTokbi0r"
1834},
1835"outputs": [],
1836"source": [
1837"rays_d_vec[0].reshape(H, W, 3)[0, W//2, :]"
1838]
1839},
1840{
1841"cell_type": "code",
1842"execution_count": null,
1843"metadata": {
1844"id": "RDksiHZkh8lF"
1845},
1846"outputs": [],
1847"source": [
1848"plt.imshow(envmap_gt)"
1849]
1850},
1851{
1852"cell_type": "code",
1853"execution_count": null,
1854"metadata": {
1855"id": "lXdTJL1Ozw1Y"
1856},
1857"outputs": [],
1858"source": [
"# Load the RGB channels of the Lambertian reference render, scaled to [0, 1].\n",
1859"with open(f'{DIRECTORY}/r_0_lamb.png', 'rb') as f:\n",
1860"  img_lamb = np.array(Image.open(f))[:, :, :3] / 255.0\n"
1861]
1862},
1863{
1864"cell_type": "code",
1865"execution_count": null,
1866"metadata": {
1867"id": "raQb_af6vO1M"
1868},
1869"outputs": [],
1870"source": [
1871"plt.imshow(alpha_gt[ind].reshape(H, W))"
1872]
1873},
1874{
1875"cell_type": "code",
1876"execution_count": null,
1877"metadata": {
1878"id": "7F65WtLmReQL"
1879},
1880"outputs": [],
1881"source": [
1882"ind = 20\n",
1883"\n",
1884"with open(f'{DIRECTORY}/r_{ind}_larger.png', 'rb') as f:\n",
1885" res_gt = np.float32(Image.open(f)) / 255.0\n",
1886" alpha = res_gt[:, :, 3:]\n",
1887" res_gt = res_gt[:, :, :3] * alpha\n",
1888"\n",
1889"#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)\n",
1890"#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5 # + 4.0\n",
1891"#inds = jnp.stack([jnp.mod(elevation, envmap_H), jnp.mod(azimuth, envmap_W)], axis=0)\n",
1892"#envmap = interp2d(envmap_gt, inds)\n",
1893"#print(envmap.shape)\n",
1894"envmap = envmap_gt\n",
1895"\n",
1896"rows = []\n",
1897"for i in range(H):\n",
1898" inds = np.arange(W) + i * W\n",
1899" materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1900" res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, {'albedo': jnp.ones((W, 3))*0.15}, normals_gt[ind, inds], rays_d_r[ind, inds], alpha.reshape(-1, 1)[inds]))\n",
1901" rows.append(res)\n",
1902"\n",
1903"res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1904"plt.imshow(res)\n",
1905"plt.figure()\n",
1906"\n",
1907"plt.imshow(res_gt)\n",
1908"\n",
1909"plt.figure()\n",
1910"diff = res - res_gt\n",
1911"plt.imshow(jnp.abs(diff).sum(-1)/3, cmap='gray')\n",
1912"plt.colorbar()\n",
1913"\n",
1914"print(jnp.abs(diff[30:50, 80:90]).max())"
1915]
1916},
1917{
1918"cell_type": "code",
1919"execution_count": null,
1920"metadata": {
1921"id": "udan6pb0Tk7b"
1922},
1923"outputs": [],
1924"source": []
1925},
1926{
1927"cell_type": "code",
1928"execution_count": null,
1929"metadata": {
1930"id": "CGoSY7aIRmpz"
1931},
1932"outputs": [],
1933"source": []
1934},
1935{
1936"cell_type": "code",
1937"execution_count": null,
1938"metadata": {
1939"collapsed": true,
1940"id": "vr0zZ_LavnOQ"
1941},
1942"outputs": [],
1943"source": [
1944"ind = 0\n",
1945"res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1946"#res = render(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R.T, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1947"# ???????????\n",
1948"plt.imshow(linear_to_srgb(res))\n",
1949"plt.figure()\n",
1950"plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1951"plt.figure()"
1952]
1953},
1954{
1955"cell_type": "code",
1956"execution_count": null,
1957"metadata": {
1958"id": "XXPbbl_Bq07h"
1959},
1960"outputs": [],
1961"source": [
1962"ind = 20\n",
1963"#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:\n",
1964"with open(f'{DIRECTORY}/r_{ind}_envmap_nn.png', 'rb') as f:\n",
1965" res_gt = np.float32(Image.open(f)) / 255.0\n",
1966" alpha = res_gt[:, :, 3:]\n",
1967" res_gt = res_gt[:, :, :3] * alpha\n",
1968"res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
1969"#print(res[60:70, 60:70, :])\n",
1970"res = linear_to_srgb(res)\n",
1971"plt.subplot(121)\n",
1972"plt.imshow(res)\n",
1973"plt.axis('off')\n",
1974"plt.subplot(122)\n",
1975"#plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1976"plt.imshow(res_gt)\n",
1977"plt.axis('off')\n"
1978]
1979},
1980{
1981"cell_type": "code",
1982"execution_count": null,
1983"metadata": {
1984"id": "Hrd8TAj9Hxc0"
1985},
1986"outputs": [],
1987"source": [
1988"#plt.imshow(3*(res - res_gt))\n",
1989"plt.imshow(0.5 * srgb_to_linear(res_gt) / srgb_to_linear(res))\n",
1990"plt.axis('off')"
1991]
1992},
1993{
1994"cell_type": "code",
1995"execution_count": null,
1996"metadata": {
1997"id": "O_QxjQomHiSb"
1998},
1999"outputs": [],
2000"source": [
2001"plt.scatter(res[:, :, 0], res[:, :, 1])\n",
2002"plt.figure()\n",
2003"plt.scatter(res_gt[:, :, 0], res_gt[:, :, 1])"
2004]
2005},
2006{
2007"cell_type": "code",
2008"execution_count": null,
2009"metadata": {
2010"id": "7XzYX89pn78B"
2011},
2012"outputs": [],
2013"source": [
2014"plt.imshow(res - res_gt, cmap='gray')\n"
2015]
2016},
2017{
2018"cell_type": "code",
2019"execution_count": null,
2020"metadata": {
2021"id": "jXFCPk7dY7Eg"
2022},
2023"outputs": [],
2024"source": [
2025"plt.imshow((res - res_gt).sum(-1) / 3.0, cmap='gray')\n",
2026"plt.colorbar()"
2027]
2028},
2029{
2030"cell_type": "code",
2031"execution_count": null,
2032"metadata": {
2033"id": "XbawH1gkzfyl"
2034},
2035"outputs": [],
2036"source": [
2037"ind = 70\n",
2038"#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:\n",
2039"with open(f'{DIRECTORY}/r_{ind}.png', 'rb') as f:\n",
2040" res_gt = np.float32(Image.open(f)) / 255.0\n",
2041" alpha = res_gt[:, :, 3:]\n",
2042" res_gt = res_gt[:, :, :3] * alpha\n",
2043"res = render(envmap_gt*0.0 + jnp.where(jnp.cos(omega_phi.reshape(envmap_H, envmap_W, 1) + 0.0) \u003c 0, 1.0, 0.0),\n",
2044" envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*1.0}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2045"#print(res[60:70, 60:70, :])\n",
2046"res = linear_to_srgb(res)\n",
2047"plt.subplot(121)\n",
2048"plt.imshow(res, vmin=0.0, vmax=1.0)\n",
2049"plt.axis('off')\n",
2050"plt.subplot(122)\n",
2051"#plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
2052"plt.imshow(res_gt, vmin=0.0, vmax=1.0)\n",
2053"plt.axis('off')\n",
2054"plt.figure()\n",
2055"diff = res - res_gt\n",
2056"plt.imshow(-diff[:, :, 0], cmap='gray')\n",
2057"plt.colorbar()"
2058]
2059},
2060{
2061"cell_type": "code",
2062"execution_count": null,
2063"metadata": {
2064"id": "4pvhvCCPOqij"
2065},
2066"outputs": [],
2067"source": [
2068"res[70:90, 70:90, :].min(), res_gt[70:90, 70:90, :].min()"
2069]
2070},
2071{
2072"cell_type": "code",
2073"execution_count": null,
2074"metadata": {
2075"id": "3nmFDWBzOzaj"
2076},
2077"outputs": [],
2078"source": [
2079"diff[60, 60, :]"
2080]
2081},
2082{
2083"cell_type": "code",
2084"execution_count": null,
2085"metadata": {
2086"id": "rAaevaJN1b6E"
2087},
2088"outputs": [],
2089"source": []
2090},
2091{
2092"cell_type": "code",
2093"execution_count": null,
2094"metadata": {
2095"id": "M1qNaYMqeGHs"
2096},
2097"outputs": [],
2098"source": [
2099"a = alpha_gt[img_ind].reshape(H, W, 1)\n",
2100"plt.figure(figsize=[12, 12])\n",
2101"plt.subplot(121)\n",
2102"plt.imshow(res_srgb * a + (1.0 - a))\n",
2103"plt.axis('off')\n",
2104"plt.subplot(122)\n",
2105"plt.imshow(img)\n",
2106"plt.axis('off')"
2107]
2108},
2109{
2110"cell_type": "code",
2111"execution_count": null,
2112"metadata": {
2113"id": "mXzlZ7xyeazs"
2114},
2115"outputs": [],
2116"source": [
2117"media.show_video([img_ggx, res_srgb], fps=2, height=200)"
2118]
2119},
2120{
2121"cell_type": "code",
2122"execution_count": null,
2123"metadata": {
2124"id": "A3A_bf80cL79"
2125},
2126"outputs": [],
2127"source": [
2128"#with open(f'{DIRECTORY}/r_0.png', 'rb') as f:\n",
2129"#  img = np.array(Image.open(f))[:, :, :3] / 255.0\n",
2130"# Load two reference PNGs from DIRECTORY, keeping the first 3 channels and dividing by 255.\n",
2131"with open(f'{DIRECTORY}/r_0_true_mirror.png', 'rb') as f:\n",
2132"  img_tm = np.array(Image.open(f))[:, :, :3] / 255.0\n",
2133"\n",
2134"with open(f'{DIRECTORY}/r_0_ggx_mirror.png', 'rb') as f:\n",
2135"  img_ggx = np.array(Image.open(f))[:, :, :3] / 255.0\n"
2136]
2137},
2138{
2139"cell_type": "code",
2140"execution_count": null,
2141"metadata": {
2142"id": "vZuJPrZKdj-U"
2143},
2144"outputs": [],
2145"source": [
2146"plt.imshow(jnp.log10(jnp.abs(img - res_srgb).sum(-1)/3), cmap='gray')\n",
2147"plt.colorbar()"
2148]
2149},
2150{
2151"cell_type": "code",
2152"execution_count": null,
2153"metadata": {
2154"id": "jtz1eDbGsQtR"
2155},
2156"outputs": [],
2157"source": []
2158},
2159{
2160"cell_type": "code",
2161"execution_count": null,
2162"metadata": {
2163"id": "TNpwJkd0sgvc"
2164},
2165"outputs": [],
2166"source": [
2167"# Strength: 0.5\n",
2168"img_lamb[64, 64]"
2169]
2170},
2171{
2172"cell_type": "code",
2173"execution_count": null,
2174"metadata": {
2175"id": "GjmlQ241s7ij"
2176},
2177"outputs": [],
2178"source": [
2179"# Strength: 0.25\n",
2180"img_lamb[64, 64]"
2181]
2182},
2183{
2184"cell_type": "code",
2185"execution_count": null,
2186"metadata": {
2187"id": "BhsBMDZzshXb"
2188},
2189"outputs": [],
2190"source": [
2191"# Strength: 0.2\n",
2192"img_lamb[64, 64]"
2193]
2194},
2195{
2196"cell_type": "code",
2197"execution_count": null,
2198"metadata": {
2199"id": "5VexRaTAs2_T"
2200},
2201"outputs": [],
2202"source": [
2203"def linear_to_srgb(linear):\n",
2204" srgb0 = 323 / 25 * linear\n",
2205" srgb1 = (211 * linear**(5 / 12) - 11) / 200\n",
2206" return np.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
2207"\n",
2208"\n",
2209"def srgb_to_linear(srgb):\n",
2210" linear0 = srgb * 25 / 323\n",
2211" linear1 = ((200 * srgb + 11) / 211) ** (12 / 5)\n",
2212" return np.where(srgb \u003c= 0.0031308 * 25 / 323, linear0, linear1)\n",
2213"\n",
2214"srgb_to_linear(0.66666667), srgb_to_linear(0.48235294), srgb_to_linear(0.43529412)\n"
2215]
2216},
2217{
2218"cell_type": "code",
2219"execution_count": null,
2220"metadata": {
2221"id": "JEJzvVIltsvM"
2222},
2223"outputs": [],
2224"source": [
2225"def linear_to_srgb(linear):\n",
2226" srgb0 = 323 / 25 * linear\n",
2227" srgb1 = (211 * linear**(5 / 12) - 11) / 200\n",
2228" return np.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
2229"\n",
2230"print(linear_to_srgb(0.5))"
2231]
2232},
2233{
2234"cell_type": "code",
2235"execution_count": null,
2236"metadata": {
2237"id": "ku7QqJbyFPBe"
2238},
2239"outputs": [],
2240"source": [
2241"ind = 0\n",
2242"with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:\n",
2243" #res_gt = np.float32(Image.open(f)) / 255.0\n",
2244" res_gt = imageio.imread(f, 'exr')\n",
2245"\n",
2246" alpha = res_gt[:, :, 3:]\n",
2247" res_gt = res_gt[:, :, :3] * alpha\n",
2248"\n",
2249"plt.imshow(res_gt, cmap='gray', interpolation='nearest')\n",
2250"\n",
2251"\n",
2252"res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2253"res = res * alpha #res.repeat(3, 2) * alpha\n",
2254"plt.figure()\n",
2255"plt.imshow(res, cmap='gray', interpolation='nearest')\n",
2256"\n",
2257"plt.figure()\n",
2258"plt.imshow(jnp.abs(res - res_gt).sum(-1))\n",
2259"#plt.imshow(res[..., 1] - res_gt[..., 1])\n",
2260"plt.colorbar()"
2261]
2262},
2263{
2264"cell_type": "code",
2265"execution_count": null,
2266"metadata": {
2267"id": "2js5LOhdUk01"
2268},
2269"outputs": [],
2270"source": [
2271"ind = 0\n",
2272"with open(f'{DIRECTORY}/r_{ind}_bg.exr', 'rb') as f:\n",
2273" #res_gt = np.float32(Image.open(f)) / 255.0\n",
2274" res_gt = imageio.imread(f, 'exr')\n",
2275"\n",
2276" alpha = res_gt[:, :, 3:]\n",
2277" res_gt = res_gt[:, :, :3] * alpha\n",
2278"\n",
2279"res_gt_srgb = linear_to_srgb(res_gt)\n",
2280"plt.imshow(res_gt_srgb, interpolation='nearest')\n",
2281"\n",
2282"\n",
2283"#omega_phi, omega_theta\n",
2284"dirs = rays_d_vec[ind].reshape(H, W, 3)\n",
2285"#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)\n",
2286"#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5\n",
2287"#inds = jnp.stack([elevation, azimuth], axis=0)\n",
2288"#env = interp2d(envmap_gt, inds)\n",
2289"dirs_azimuth = np.arctan2(dirs[..., 1], dirs[..., 0])\n",
2290"nr = np.sqrt(dirs[..., 1] ** 2 + dirs[..., 0] ** 2)\n",
2291"dirs_elevation = np.arctan2(nr, dirs[..., 2])\n",
2292"import scipy\n",
2293"plt.figure()\n",
2294"channels = []\n",
2295"for i in range(3):\n",
2296" interp = scipy.interpolate.interp2d(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0], envmap_gt[:, :, i])\n",
2297" #ch = interp(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0])\n",
2298" ch = interp(dirs_azimuth[0, :], dirs_elevation[:, 0])\n",
2299" channels.append(ch)\n",
2300"#plt.imshow(interp(omega_phi, omega_theta).reshape(envmap_H, envmap_W))\n",
2301"plt.imshow(linear_to_srgb(jnp.stack(channels, axis=-1)))\n",
2302"#res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2303"#res = res * alpha #res.repeat(3, 2) * alpha\n",
2304"#plt.figure()\n",
2305"#plt.imshow(res, cmap='gray', interpolation='nearest')\n",
2306"\n",
2307"#plt.figure()\n",
2308"#plt.imshow(jnp.abs(res - res_gt).sum(-1))\n",
2309"#plt.imshow(res[..., 1] - res_gt[..., 1])\n",
2310"#plt.colorbar()"
2311]
2312},
2313{
2314"cell_type": "code",
2315"execution_count": null,
2316"metadata": {
2317"id": "h8DjVEBMuDwJ"
2318},
2319"outputs": [],
2320"source": [
2321"# When shifting by half a pixel this is the error"
2322]
2323},
2324{
2325"cell_type": "code",
2326"execution_count": null,
2327"metadata": {
2328"id": "BhPEi4MioGj3"
2329},
2330"outputs": [],
2331"source": [
2332"plt.imshow(envmap_gt)"
2333]
2334},
2335{
2336"cell_type": "code",
2337"execution_count": null,
2338"metadata": {
2339"id": "tndUYnK7kVjr"
2340},
2341"outputs": [],
2342"source": [
2343"res.min(), res.mean(), res.max()"
2344]
2345},
2346{
2347"cell_type": "code",
2348"execution_count": null,
2349"metadata": {
2350"id": "FkrC_npUkgLk"
2351},
2352"outputs": [],
2353"source": [
2354"res_gt.min(), res_gt.mean(), res_gt.max()"
2355]
2356},
2357{
2358"cell_type": "code",
2359"execution_count": null,
2360"metadata": {
2361"id": "TYbKAZiUgJLJ"
2362},
2363"outputs": [],
2364"source": [
2365"media.show_video([res, res_gt], fps=2, height=256)"
2366]
2367},
2368{
2369"cell_type": "code",
2370"execution_count": null,
2371"metadata": {
2372"id": "jsEnZ94pl0Zq"
2373},
2374"outputs": [],
2375"source": [
2376"envmap_gt.min(), envmap_gt.max()"
2377]
2378},
2379{
2380"cell_type": "code",
2381"execution_count": null,
2382"metadata": {
2383"id": "_O3aObadmSAB"
2384},
2385"outputs": [],
2386"source": [
2387"envmap_gt.min(), envmap_gt.max()"
2388]
2389},
2390{
2391"cell_type": "code",
2392"execution_count": null,
2393"metadata": {
2394"id": "OTtZ2u2HuMvY"
2395},
2396"outputs": [],
2397"source": [
2398"ind = 1\n",
2399"with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:\n",
2400" #res_gt = np.float32(Image.open(f)) / 255.0\n",
2401" res_gt = imageio.imread(f, 'exr')\n",
2402"\n",
2403" alpha = res_gt[:, :, 3:]\n",
2404" res_gt = res_gt[:, :, :1] * alpha\n",
2405"\n",
2406"plt.imshow(res_gt, cmap='gray')\n",
2407"\n",
2408"res = jnp.maximum(0.0, (normals_gt[ind].reshape(H, W, 3) * jnp.array([0.0, 0.0, 1.0])[None, None, :]).sum(-1, keepdims=True)) / jnp.pi\n",
2409"res = res * alpha #res.repeat(3, 2) * alpha\n",
2410"plt.figure()\n",
2411"plt.imshow(res, cmap='gray')\n",
2412"\n",
2413"plt.figure()\n",
2414"plt.imshow(res - res_gt)\n",
2415"plt.colorbar()"
2416]
2417},
2418{
2419"cell_type": "code",
2420"execution_count": null,
2421"metadata": {
2422"id": "GoZrZ53vCFqf"
2423},
2424"outputs": [],
2425"source": [
2426"plt.imshow(res_gt - res)"
2427]
2428},
2429{
2430"cell_type": "code",
2431"execution_count": null,
2432"metadata": {
2433"id": "8E18U5PULGp_"
2434},
2435"outputs": [],
2436"source": [
2437"jnp.linalg.norm(normals_gt[0].reshape(H, W, 3), axis=-1).max()"
2438]
2439},
2440{
2441"cell_type": "code",
2442"execution_count": null,
2443"metadata": {
2444"id": "mxThbxIwLOHT"
2445},
2446"outputs": [],
2447"source": [
2448"def foo(rays_d, rays_o, rad=1.0):\n",
2449"  \"\"\"Plots unit normals at the smaller-t hit of each ray with an origin-centred sphere of radius rad.\"\"\"\n",
2450"  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
2451"  o_norm_sq = (rays_o ** 2).sum(-1, keepdims=True)\n",
2452"  d_dot_o = (rays_o * rays_d).sum(-1, keepdims=True)\n",
2453"  # Quarter discriminant of |rays_o + t * rays_d|^2 = rad^2, a quadratic in t.\n",
2454"  disc = d_norm_sq * (rad ** 2 - o_norm_sq) + d_dot_o ** 2\n",
2455"  # Smaller root; dividing by |d|^2 keeps it correct for unnormalized directions. Misses get inf.\n",
2456"  t_surface = jnp.where(disc \u003e 0, (-jnp.sqrt(disc) - d_dot_o) / d_norm_sq, jnp.inf)  # [H, W, 1]\n",
2457"  pts = rays_o + rays_d * t_surface\n",
2458"  normals = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)\n",
2459"  # Map normal components from [-1, 1] to [0, 1] for display.\n",
2460"  plt.imshow(normals.reshape(H, W, 3) * 0.5 + 0.5)\n",
2461"\n",
2462"ind = 111\n",
2463"foo(rays_d_vec[ind], rays_o_vec[ind])\n",
2464"plt.figure()\n",
2465"# normals_gt for the same index, shown in a second figure for comparison.\n",
2466"n = normals_gt[ind]\n",
2467"plt.imshow(n.reshape(H, W, 3) * 0.5 + 0.5)"
2468]
2469},
2470{
2471"cell_type": "code",
2472"execution_count": null,
2473"metadata": {
2474"id": "BUx2MdOkIl2C"
2475},
2476"outputs": [],
2477"source": [
2478"rays_d_vec[-1].reshape(H, W, 3)[64, -1, :]"
2479]
2480},
2481{
2482"cell_type": "code",
2483"execution_count": null,
2484"metadata": {
2485"id": "-HgV69GTssmP"
2486},
2487"outputs": [],
2488"source": [
2489"n = normals_gt[0].reshape(H, W, 3) #@ R.T\n",
2490"plt.imshow(n * 0.5 + 0.5)"
2491]
2492},
2493{
2494"cell_type": "code",
2495"execution_count": null,
2496"metadata": {
2497"id": "lS6xAAxRxXeQ"
2498},
2499"outputs": [],
2500"source": []
2501},
2502{
2503"cell_type": "code",
2504"execution_count": null,
2505"metadata": {
2506"id": "aAzT6N1HtBHZ"
2507},
2508"outputs": [],
2509"source": [
2510"focal = .5 * W / np.tan(0.5 * 0.691111147403717)\n",
2511"pixtocams = camera_utils.get_pixtocam(focal, W, H)\n",
2512"c2w = jnp.array([[ 0.0, 1.0, 0.0, 0.0],\n",
2513" [-1.0, 0.0, 0.0, 0.0],\n",
2514" [ 0.0, 0.0, 1.0, 4.0],\n",
2515" [ 0.0, 0.0, 0.0, 1.0]])\n",
2516"origins, directions, _, _, _ = camera_utils.pixels_to_rays(\n",
2517" jnp.array([W/2.0]),\n",
2518" jnp.array([0]),\n",
2519" pixtocams,\n",
2520" c2w)\n",
2521"print(origins)\n",
2522"print(directions) # 'up' is x"
2523]
2524},
2525{
2526"cell_type": "code",
2527"execution_count": null,
2528"metadata": {
2529"id": "xtu32GGARIIj"
2530},
2531"outputs": [],
2532"source": []
2533},
2534{
2535"cell_type": "code",
2536"execution_count": null,
2537"metadata": {
2538"id": "k4fgXxvrRIqf"
2539},
2540"outputs": [],
2541"source": [
2542"rays_d_vec[0].reshape(H, W, 3)[0, 64, :]"
2543]
2544},
2545{
2546"cell_type": "code",
2547"execution_count": null,
2548"metadata": {
2549"id": "RtguZTjrLclp"
2550},
2551"outputs": [],
2552"source": [
2553"surface_pts = rays_o_vec + t_surface_gt * rays_d_vec\n",
2554"print(surface_pts.shape)"
2555]
2556},
2557{
2558"cell_type": "code",
2559"execution_count": null,
2560"metadata": {
2561"id": "aChrXPwo5bFD"
2562},
2563"outputs": [],
2564"source": [
2565"finite_surface_points.shape"
2566]
2567},
2568{
2569"cell_type": "code",
2570"execution_count": null,
2571"metadata": {
2572"id": "KfdcZm1vxZUu"
2573},
2574"outputs": [],
2575"source": [
2576"@jax.jit\n",
2577"def get_dists_sq(x, y):\n",
2578"  \"\"\"Squared Euclidean distance between broadcastable x and y, summed over the last axis.\"\"\"\n",
2579"  return ((x - y) ** 2).sum(-1)\n",
2580"\n",
2581"def subsample_point_cloud(points, num_points):\n",
2582"  \"\"\"Greedy farthest-point selection among random candidates; returns (point list, index list).\"\"\"\n",
2583"  new_points, inds = [points[0]], []  # points[0] seeds the set; its index is not added to inds\n",
2584"  for i in range(num_points - 1):\n",
2585"    cand = np.random.choice(points.shape[0], size=(min(1000, points.shape[0]),), replace=False)\n",
2586"    if i % 100 == 0:\n",
2587"      print(i)\n",
2588"    dists = get_dists_sq(points[cand, None, :], jnp.stack(new_points, axis=0)[None, :, :])\n",
2589"    best = cand[jnp.argmax(dists.min(axis=1))]  # candidate farthest from the points chosen so far\n",
2590"    new_points.append(points[best])\n",
2591"    inds.append(best)\n",
2592"  return new_points, inds\n",
2593"\n",
2594"flat_pts = surface_pts.reshape(-1, 3)  # new name so surface_pts keeps its original shape\n",
2595"finite_surface_points = flat_pts[jnp.all(jnp.isfinite(flat_pts), axis=-1)]\n",
2596"surface_pts_subsampled, subsample_inds = subsample_point_cloud(finite_surface_points, 10000)\n",
2597"surface_pts_subsampled = jnp.stack(surface_pts_subsampled, axis=0)"
2598]
2599},
2600{
2601"cell_type": "code",
2602"execution_count": null,
2603"metadata": {
2604"id": "6s0ThFOJyx5l"
2605},
2606"outputs": [],
2607"source": [
2608"fig = plt.figure()\n",
2609"ax = fig.add_subplot(projection='3d')\n",
2610"\n",
2611"ax.scatter(surface_pts_subsampled[..., 0],\n",
2612" surface_pts_subsampled[..., 1],\n",
2613" surface_pts_subsampled[..., 2])\n"
2614]
2615},
2616{
2617"cell_type": "code",
2618"execution_count": null,
2619"metadata": {
2620"id": "PUELUiCD5Oxu"
2621},
2622"outputs": [],
2623"source": [
2624"with open(f'{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'wb') as f:  # f-string: interpolate DIRECTORY\n",
2625"  np.save(f, surface_pts_subsampled)\n"
2626]
2627},
2628{
2629"cell_type": "code",
2630"execution_count": null,
2631"metadata": {
2632"id": "gSahd5OS0hws"
2633},
2634"outputs": [],
2635"source": [
2636"plt.plot(surface_pts_subsampled[..., 0], surface_pts_subsampled[..., 1], '.')"
2637]
2638},
2639{
2640"cell_type": "code",
2641"execution_count": null,
2642"metadata": {
2643"id": "OrwX1dfmB45A"
2644},
2645"outputs": [],
2646"source": [
2647"surface_pts_subsampled.shape"
2648]
2649},
2650{
2651"cell_type": "markdown",
2652"metadata": {
2653"id": "oYslVMJlH0Zs"
2654},
2655"source": [
2656"# Cache visibility using MLP\n",
2657"\n",
2658"Optimize an MLP mapping from position and direction to visibility:\n",
2659"$$(\\mathbf{x}, \\boldsymbol{\\omega}) \\mapsto v$$\n",
2660"where $v$ is a scalar visibility in $[0, 1]$ (constrained by a sigmoid), and the position and direction are 3-vectors with positional encoding."
2661]
2662},
2663{
2664"cell_type": "code",
2665"execution_count": null,
2666"metadata": {
2667"id": "SCMrkNMuzyaM"
2668},
2669"outputs": [],
2670"source": [
2671"#with open('{DIRECTORY}/hotdog_surface_pts.npy', 'rb') as f:\n",
2672"# surface_pts = np.load(f)\n"
2673]
2674},
2675{
2676"cell_type": "code",
2677"execution_count": null,
2678"metadata": {
2679"id": "7CCAtFgjxwI2"
2680},
2681"outputs": [],
2682"source": []
2683},
2684{
2685"cell_type": "code",
2686"execution_count": null,
2687"metadata": {
2688"id": "b2MJEDxTxnfX"
2689},
2690"outputs": [],
2691"source": [
2692"surface_normals_subsampled.shape, omega_xyz.shape, occlusion_masks.shape"
2693]
2694},
2695{
2696"cell_type": "code",
2697"execution_count": null,
2698"metadata": {
2699"id": "rwOrxx5Qbnqj"
2700},
2701"outputs": [],
2702"source": [
2703"with open(f'{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'rb') as f:\n",
2704"  surface_pts_subsampled = np.load(f)\n",
2705"\n",
2706"with open(f'{DIRECTORY}/subsampling_indices.npy', 'rb') as f:\n",
2707"  indices = np.load(f)\n",
2708"\n",
2709"with open(f'{DIRECTORY}/visibility_images.npy', 'rb') as f:\n",
2710"  occlusion_masks = jnp.float32(np.load(f)) / 255.0\n",
2711"\n",
2712"with open(f'{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'rb') as f:\n",
2713"  surface_normals_subsampled = np.load(f)\n",
2714"\n",
2715"envmap_H, envmap_W = occlusion_masks.shape[1:]\n",
2716"# Pixel-centre angles: phi in (-pi, pi) varies along columns, theta in (0, pi) along rows.\n",
2717"omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),\n",
2718"                                      jnp.linspace(0.0, jnp.pi, envmap_H+1)[:-1] + jnp.pi / (2.0 * envmap_H))\n",
2719"# Product of the theta step and the phi step between neighbouring grid samples.\n",
2720"dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])\n",
2721"\n",
2722"omega_theta = omega_theta.flatten()\n",
2723"omega_phi = omega_phi.flatten()\n",
2724"# Unit direction for every (theta, phi) sample: spherical coordinates with z as the polar axis.\n",
2725"omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)\n",
2726"omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)\n",
2727"omega_z = jnp.cos(omega_theta)\n",
2728"omega_xyz = jnp.stack([omega_x,\n",
2729"                       omega_y,\n",
2730"                       omega_z], axis=-1)\n",
2731"\n",
2732"\n",
2733"# Turn the negative hemisphere (normal . omega not positive) into nans\n",
2734"occlusion_masks = jnp.where(jnp.sum(surface_normals_subsampled[:, None, :] * omega_xyz[None, :, :], axis=-1).reshape(-1, envmap_H, envmap_W) \u003e 0.0,\n",
2735"                            occlusion_masks, jnp.nan)\n",
2736"# Alternative: pass 0.0 instead of jnp.nan to zero out the negative hemisphere.\n"
2737]
2738},
2739{
2740"cell_type": "code",
2741"execution_count": null,
2742"metadata": {
2743"id": "vhGs5Ip9x7W-"
2744},
2745"outputs": [],
2746"source": [
2747"plt.imshow(occlusion_masks[70])"
2748]
2749},
2750{
2751"cell_type": "code",
2752"execution_count": null,
2753"metadata": {
2754"id": "C5YOTkoNR9mc"
2755},
2756"outputs": [],
2757"source": []
2758},
2759{
2760"cell_type": "code",
2761"execution_count": null,
2762"metadata": {
2763"id": "dZ8y1LMdB5cM"
2764},
2765"outputs": [],
2766"source": [
2767"class MLP(nn.Module):\n",
2768"  \"\"\"Dense ReLU network; the last entry of `features` is the final layer width (no activation).\"\"\"\n",
2769"  features: Sequence[int]\n",
2770"\n",
2771"  @nn.compact\n",
2772"  def __call__(self, x, y=None):\n",
2773"    # An optional second input is concatenated onto x along the last axis.\n",
2774"    if y is not None:\n",
2775"      x = jnp.concatenate([x, y], axis=-1)\n",
2776"    for feat in self.features[:-1]:\n",
2777"      x = nn.relu(nn.Dense(feat)(x))\n",
2778"    return nn.Dense(self.features[-1])(x)\n",
2779"\n",
2780"\n",
2781"append_identity = True\n",
2782"def posenc(x, L_encoding):\n",
2783"  \"\"\"Sin/cos positional encoding of the last axis of x at scales 2**0 .. 2**(L_encoding - 1).\n",
2784"\n",
2785"  Args:\n",
2786"    x: array whose last axis is encoded.\n",
2787"    L_encoding: number of scales; x is returned unchanged when it is not positive.\n",
2788"\n",
2789"  Returns:\n",
2790"    The encoded features, with x prepended when the module-level append_identity is true.\n",
2791"  \"\"\"\n",
2792"  if L_encoding \u003c= 0:\n",
2793"    return x\n",
2794"  scales = 2**jnp.arange(L_encoding)\n",
2795"  scaled_x = x[..., None, :] * scales[:, None]  # [..., L, D]\n",
2796"  # sin(t + pi/2) equals cos(t), so a single sin call yields both sin and cos features.\n",
2797"  four_feat = jnp.sin(\n",
2798"      jnp.stack([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))  # [..., L, D, 2]\n",
2799"  four_feat = jnp.reshape(four_feat, x.shape[:-1] + (-1,))\n",
2800"  if append_identity:\n",
2801"    return jnp.concatenate([x, four_feat], axis=-1)\n",
2802"  return four_feat\n",
2803"\n",
2804"\n",
2805"# Initialize visibility MLP\n",
2806"L_encoding_vis_x = 3\n",
2807"L_encoding_vis_dir = 6\n",
2808"# Each encoded 3-vector has 3 + 6 * L entries (identity plus sin and cos per scale).\n",
2809"mlp_input_features = 6 + 6 * (L_encoding_vis_x + L_encoding_vis_dir)\n",
2810"\n",
2811"num_components = 1\n",
2812"mlp_vis = MLP([128]*4 + [num_components])\n",
2813"\n",
2814"params_vis = mlp_vis.init(jax.random.PRNGKey(0),\n",
2815"                          np.zeros([1, mlp_input_features]))\n",
2816"\n",
2817"init_lr_vis = 1e-2\n",
2818"x_batch_size = 50  # surface points per step; defined before the loss that reads it\n",
2819"\n",
2820"init_vis, update_vis, get_params_vis = jax.experimental.optimizers.adam(init_lr_vis)\n",
2821"state_vis = init_vis(params_vis)\n",
2822"\n",
2823"\n",
2824"def get_vis_loss(params, x_input, dir_input, vis_gt):\n",
2825"  \"\"\"Squared visibility error summed over non-NaN targets.\n",
2826"\n",
2827"  The sum is divided by x_batch_size, envmap_H and envmap_W, all read from module scope.\n",
2828"  \"\"\"\n",
2829"  vis_pred = jax.nn.sigmoid(mlp_vis.apply(params,\n",
2830"                                          posenc(x_input, L_encoding_vis_x),\n",
2831"                                          posenc(dir_input, L_encoding_vis_dir)))\n",
2832"\n",
2833"  # Mask NaN targets before subtracting: nansum over (pred - nan) ** 2 gives the same value\n",
2834"  # but NaN gradients, because the backward pass multiplies a zero cotangent by nan.\n",
2835"  valid = ~jnp.isnan(vis_gt)\n",
2836"  err = jnp.where(valid, vis_pred - jnp.where(valid, vis_gt, 0.0), 0.0)\n",
2837"  return jnp.sum(err ** 2) / x_batch_size / envmap_H / envmap_W\n",
2838"\n",
2839"\n",
2840"@jax.jit\n",
2841"def step(state, x_input, dir_input, vis_gt, i):\n",
2842"  \"\"\"Applies one optimizer update at step index i; returns (new optimizer state, loss).\"\"\"\n",
2843"  params = get_params_vis(state)\n",
2844"  loss, grad = jax.value_and_grad(get_vis_loss)(params, x_input, dir_input, vis_gt)\n",
2845"  return update_vis(i, grad, state), loss\n",
2846"\n",
2847"\n",
2848"losses = []\n",
2849"for i in range(500):\n",
2850"  image_indices = np.random.randint(0, surface_pts_subsampled.shape[0], size=(x_batch_size,))\n",
2851"\n",
2852"  # Pair every sampled surface point with every environment-map direction.\n",
2853"  x_input = surface_pts_subsampled[image_indices].reshape(-1, 1, 3).repeat(envmap_H * envmap_W, axis=1).reshape(-1, 3)\n",
2854"  dir_input = omega_xyz.reshape(1, -1, 3).repeat(x_batch_size, axis=0).reshape(-1, 3)\n",
2855"  vis_gt = occlusion_masks[image_indices].reshape(-1, 1)\n",
2856"\n",
2857"  state_vis, loss = step(state_vis, x_input, dir_input, vis_gt, i)\n",
2858"  losses.append(loss)\n",
2859"\n",
2860"plt.semilogy(losses)"
2861]
2862},
2863{
2864"cell_type": "code",
2865"execution_count": null,
2866"metadata": {
2867"id": "JtrQf09xzths"
2868},
2869"outputs": [],
2870"source": [
2871"jnp.nansum(occlusion_masks * 2 - 1)"
2872]
2873},
2874{
2875"cell_type": "code",
2876"execution_count": null,
2877"metadata": {
2878"id": "idm3yLifaMRD"
2879},
2880"outputs": [],
2881"source": [
2882"@jax.jit\n",
2883"def evaluate_ind(params, x_input, dir_input, ind):\n",
2884"  \"\"\"Sigmoid visibility prediction for each (position, direction) row; `ind` is unused.\"\"\"\n",
2885"  return jax.nn.sigmoid(mlp_vis.apply(params,\n",
2886"                                      posenc(x_input, L_encoding_vis_x),\n",
2887"                                      posenc(dir_input, L_encoding_vis_dir)))\n",
2888"\n",
2889"\n",
2890"def evaluate(params):\n",
2891"  \"\"\"Returns (total, per-point mean) absolute error of thresholded predictions vs occlusion_masks.\"\"\"\n",
2892"  dir_input = omega_xyz\n",
2893"  total_error = 0.0\n",
2894"  for ind in range(surface_pts_subsampled.shape[0]):\n",
2895"    if (ind + 1) % 1000 == 0:\n",
2896"      print(ind + 1)\n",
2897"    x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)\n",
2898"    mask = evaluate_ind(params, x_input, dir_input, ind).reshape(-1)\n",
2899"    hard_mask = jnp.float32(mask \u003e 0.5)\n",
2900"    mask_gt = occlusion_masks[ind].reshape(-1)\n",
2901"    # nansum skips the directions whose ground-truth entry is NaN.\n",
2902"    total_error += jnp.nansum(jnp.abs(hard_mask - mask_gt))\n",
2903"  return total_error, total_error / surface_pts_subsampled.shape[0]\n",
2904"\n",
2905"\n",
2906"ind = 70\n",
2907"x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)\n",
2908"dir_input = omega_xyz\n",
2909"\n",
2910"mask = evaluate_ind(get_params_vis(state_vis), x_input, dir_input, ind).reshape(envmap_H, envmap_W)\n",
2911"plt.imshow(mask)\n",
2912"plt.figure()\n",
2913"plt.imshow(mask \u003e 0.5)\n",
2914"plt.figure()\n",
2915"plt.imshow(occlusion_masks[ind])\n",
2916"\n",
2917"# evaluate() loops over every subsampled point, so it only runs when explicitly enabled.\n",
2918"RUN_FULL_EVALUATION = False\n",
2919"if RUN_FULL_EVALUATION:\n",
2920"  error, avg_error = evaluate(get_params_vis(state_vis))\n",
2921"  print(avg_error)"
2922]
2923},
2924{
2925"cell_type": "code",
2926"execution_count": null,
2927"metadata": {
2928"id": "tSf39C_s0bSn"
2929},
2930"outputs": [],
2931"source": [
2932"# For each subsampled point, store the index triple of its nearest neighbour in surface_pts.\n",
2933"indices = []\n",
2934"for ind in range(surface_pts_subsampled.shape[0]):\n",
2935"  d = ((surface_pts - surface_pts_subsampled[ind]) ** 2).sum(-1)\n",
2936"  i, r, c = np.unravel_index(np.argmin(d), d.shape)  # d must be 3-D for the triple unpack\n",
2937"  indices.append((i, r, c))\n",
2938"with open(f'{DIRECTORY}/subsampling_indices.npy', 'wb') as f:\n",
2939"  np.save(f, np.array(indices))\n"
2940]
2941},
2942{
2943"cell_type": "code",
2944"execution_count": null,
2945"metadata": {
2946"id": "LDxFU8Yy01_N"
2947},
2948"outputs": [],
2949"source": [
2950"#with open('{DIRECTORY}/subsampling_indices.npy', 'wb') as f:\n",
2951"# np.save(f, np.array(indices))\n"
2952]
2953},
2954{
2955"cell_type": "code",
2956"execution_count": null,
2957"metadata": {
2958"id": "QtHRjceE6cvi"
2959},
2960"outputs": [],
2961"source": [
2962"with open(f'{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'wb') as f:  # f-string: interpolate DIRECTORY\n",
2963"  np.save(f, surface_normals_subsampled)\n"
2964]
2965},
2966{
2967"cell_type": "code",
2968"execution_count": null,
2969"metadata": {
2970"id": "5LCuTu1-mCnm"
2971},
2972"outputs": [],
2973"source": [
2974"normals_gt.shape, indices.shape"
2975]
2976},
2977{
2978"cell_type": "code",
2979"execution_count": null,
2980"metadata": {
2981"id": "lmZ_tVAivWpW"
2982},
2983"outputs": [],
2984"source": [
2985"surface_normals_subsampled = normals_gt.reshape(-1, 128, 128, 3)[indices[:, 0], indices[:, 1], indices[:, 2], :]\n",
2986"surface_normals_subsampled.shape"
2987]
2988},
2989{
2990"cell_type": "code",
2991"execution_count": null,
2992"metadata": {
2993"id": "iVk2YZZhvg5F"
2994},
2995"outputs": [],
2996"source": [
2997"for ind in [0, 10, 20, 50, 100]:\n",
2998"#ind = 10\n",
2999" a = jnp.where(jnp.sum(surface_normals_subsampled[ind] * omega_xyz.reshape(envmap_H, envmap_W, 3), axis=-1) \u003e 0.0, occlusion_masks[ind], 0.0)\n",
3000" \n",
3001" plt.figure(); plt.imshow(a)"
3002]
3003},
3004{
3005"cell_type": "code",
3006"execution_count": null,
3007"metadata": {
3008"id": "U29UKCZAwxC2"
3009},
3010"outputs": [],
3011"source": [
3012"indices[100]"
3013]
3014},
3015{
3016"cell_type": "code",
3017"execution_count": null,
3018"metadata": {
3019"id": "MLrDOPsrwrVt"
3020},
3021"outputs": [],
3022"source": [
3023"surface_normals_subsampled[109]"
3024]
3025},
3026{
3027"cell_type": "code",
3028"execution_count": null,
3029"metadata": {
3030"id": "tY84xcMswyGI"
3031},
3032"outputs": [],
3033"source": [
3034"normals_gt.reshape(512, 128, 128, 3)[327, 96, 49, :]"
3035]
3036},
3037{
3038"cell_type": "code",
3039"execution_count": null,
3040"metadata": {
3041"id": "Ytescsnhvsag"
3042},
3043"outputs": [],
3044"source": [
3045"plt.imshow(occlusion_masks[ind])"
3046]
3047},
3048{
3049"cell_type": "code",
3050"execution_count": null,
3051"metadata": {
3052"id": "HA2YrqMYwHFP"
3053},
3054"outputs": [],
3055"source": [
3056"# A simple script that uses blender to render views of a single object by rotation the camera around it.\n",
3057"# Also produces depth map at the same time.\n",
3058"\n",
3059"import argparse, sys, os\n",
3060"import json\n",
3061"import bpy\n",
3062"import mathutils\n",
3063"import numpy as np\n",
3064" \n",
3065"def listify_matrix(matrix):\n",
3066"    \"\"\"Convert `matrix` into a plain list of lists.\n",
3067"\n",
3068"    Each item produced by iterating `matrix` (one row) is copied into its own list.\"\"\"\n",
3069"    return [list(row) for row in matrix]\n",
3070"\n",
3071"def delistify_matrix(lst):\n",
3072"    \"\"\"Copy the entries `lst[r][c]` for r, c in 0..3 into a new `mathutils.Matrix()` and return it.\"\"\"\n",
3073"    mat = mathutils.Matrix()\n",
3074"    for r, c in ((r, c) for r in range(4) for c in range(4)):\n",
3075"        mat[r][c] = lst[r][c]\n",
3076"    return mat\n",
3077"\n",
3078"\n",
3079"DEBUG = False\n",
3080"envmap_H = 50\n",
3081"envmap_W = 99\n",
3082"FORMAT = 'PNG'\n",
3083"\n",
3084"# filename is /.../\u003cmodel\u003e.blend/.../\u003cscript\u003e.py\n",
3085"ind_f = __file__.find('.blend')\n",
3086"ind_i = __file__[:ind_f].rfind('/') + 1\n",
3087"#model_name = __file__[ind_i:ind_f] + '_uniform'\n",
3088"#model_name = 'hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'\n",
3089"model_name = 'hotdog_occlusions_lambertian_linear_128x128'\n",
3090"\n",
3091"# Read from file\n",
3092"#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',\n",
3093"# f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']\n",
3094"\n",
3095"partitions = ['train']\n",
3096"transforms_files = [f'/Users/dorverbin/Downloads/blend_files/{model_name}/transforms_{partition}.json' for partition in partitions]\n",
3097"RESULTS_PATH = os.path.join(model_name, 'visibility')\n",
3098"\n",
3099"fp = bpy.path.abspath(f\"//{RESULTS_PATH}\")\n",
3100"\n",
3101"# Create the output root plus one sub-directory per partition.\n",
3102"# exist_ok=True replaces the os.path.exists() check, which could race with another process creating the directory.\n",
3103"os.makedirs(fp, exist_ok=True)\n",
3104"for partition in partitions:\n",
3105"    os.makedirs(os.path.join(fp, partition), exist_ok=True)\n",
3106"\n",
3107"# Data to store in JSON file\n",
3108"#out_data = {}\n",
3109"\n",
3110"\n",
3111"# Render Optimizations\n",
3112"bpy.context.scene.render.use_persistent_data = True\n",
3113"\n",
3114"\n",
3115"# Set up rendering of depth map.\n",
3116"bpy.context.scene.use_nodes = True\n",
3117"tree = bpy.context.scene.node_tree\n",
3118"links = tree.links\n",
3119"\n",
3120"# Add passes for additionally dumping albedo and normals.\n",
3121"bpy.context.scene.render.image_settings.file_format = str('PNG')\n",
3122"bpy.context.scene.render.image_settings.color_depth = str(8)\n",
3123"print(\"Only 32 if using EXR. When I use binary forget about it\")\n",
3124"\n",
3125"# If using OpenEXR, set to linear color space\n",
3126"bpy.data.scenes['Scene'].display_settings.display_device = 'None'\n",
3127"bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear' \n",
3128"\n",
3129"# Remove all tree nodes (loop over a snapshot: removing from the live collection mid-iteration can skip nodes).\n",
3130"for node in list(tree.nodes):\n",
3131"    tree.nodes.remove(node)\n",
3132"\n",
3133"if 'Custom Outputs' not in tree.nodes:\n",
3134" # Create input render layer node.\n",
3135" render_layers = tree.nodes.new('CompositorNodeRLayers')\n",
3136" render_layers.label = 'Custom Outputs'\n",
3137" render_layers.name = 'Custom Outputs'\n",
3138"\n",
3139"# Background\n",
3140"bpy.context.scene.render.dither_intensity = 0.0\n",
3141"bpy.context.scene.render.film_transparent = False\n",
3142"\n",
3143"# Create collection for objects not to render with background\n",
3144"\n",
3145"\n",
3146"scene = bpy.context.scene\n",
3147"scene.render.resolution_x = envmap_W\n",
3148"scene.render.resolution_y = envmap_H\n",
3149"scene.render.resolution_percentage = 100\n",
3150"\n",
3151"cam = scene.objects['Camera']\n",
3152"\n",
3153"# Define equirect camera\n",
3154"cam.data.type = 'PANO'\n",
3155"cam.data.cycles.panorama_type = 'EQUIRECTANGULAR'\n",
3156"#cam.data.cycles.latitude_min = np.pi / (2.0 * envmap_H) - np.pi / 2.0\n",
3157"#cam.data.cycles.latitude_max = np.pi / 2.0 - np.pi / (2.0 * envmap_H)\n",
3158"\n",
3159"\n",
3160"#cam.location = (0, 4.0, 0.5)\n",
3161"\n",
3162"#cam_constraint = cam.constraints.new(type='TRACK_TO')\n",
3163"#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'\n",
3164"#cam_constraint.up_axis = 'UP_Y'\n",
3165"#b_empty = parent_obj_to_camera(cam)\n",
3166"#cam_constraint.target = b_empty\n",
3167"\n",
3168"\n",
3169"#scene.render.image_settings.file_format = 'PNG' # set output format to .png\n",
3170"scene.render.image_settings.file_format = FORMAT # set output format to .png\n",
3171"\n",
3172"canonical_mat = [[ 0.0, 0.0, -1.0, 0.0],\n",
3173" [ 1.0, 0.0, 0.0, 0.0],\n",
3174" [ 0.0, 1.0, 0.0, 0.0],\n",
3175" [ 0.0, 0.0, 0.0, 1.0]]\n",
3176"\n",
3177"#hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts.npy')\n",
3178"hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts_subsampled.npy')\n",
3179"\n",
3180"all_object_names_except_camera = [k for k in scene.objects.keys() if 'Camera' not in k]\n",
3181"\n",
3182"def toggle_object_visibility(do_hide_objects):\n",
3183"    \"\"\"Set `hide_render` to `do_hide_objects` on every scene object named in `all_object_names_except_camera`.\"\"\"\n",
3184"    for object_name in all_object_names_except_camera:\n",
3185"        scene.objects[object_name].hide_render = do_hide_objects\n",
3186"def toggle_mask_visibility(do_hide_mask):\n",
3187"    \"\"\"Set input 1 of the 'Math' node in the 'World' node tree: 1.01 if `do_hide_mask`, else 0.8.\"\"\"\n",
3188"    math_input = bpy.data.worlds[\"World\"].node_tree.nodes[\"Math\"].inputs[1]\n",
3189"    math_input.default_value = 1.01 if do_hide_mask else 0.8\n",
3190"for transforms_file, partition in zip(transforms_files, partitions):  # one pass per (transforms file, partition) pair\n",
3191" if transforms_file is not None:\n",
3192" with open(transforms_file) as in_file:\n",
3193" transforms_data = json.load(in_file)\n",
3194" \n",
3195" VIEWS = len(transforms_data['frames'])  # frame count; not read again by the active code in this cell\n",
3196" else:\n",
3197" raise RuntimeError('Must specify transforms file')\n",
3198"\n",
3199" #out_data['frames'] = []\n",
3200"\n",
3201" #for i in range(0, VIEWS, 20):\n",
3202" #if partition == 'train':\n",
3203" # continue\n",
3204" \"\"\"\n",
3205" for i in range(1):\n",
3206" #cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix']) \n",
3207" #print(cam.matrix_world)\n",
3208"\n",
3209" # Start by rendering mask\n",
3210" cam.matrix_world = delistify_matrix(canonical_mat)\n",
3211"\n",
3212" toggle_object_visibility(False)\n",
3213" toggle_mask_visibility(True)\n",
3214"\n",
3215"\n",
3216" for r in [40]:\n",
3217" for c in range(30, 40):\n",
3218" p = hotdog_points[i, r, c, :]\n",
3219" if not np.all(np.isfinite(p)):\n",
3220" continue\n",
3221" \n",
3222" point_x, point_y, point_z = p\n",
3223" \n",
3224" cam.matrix_world[0][3] = point_x\n",
3225" cam.matrix_world[1][3] = point_y\n",
3226" cam.matrix_world[2][3] = point_z\n",
3227" \n",
3228" scene.render.filepath = os.path.join(fp, partition, f'r_{i}_{r}_{c}') \n",
3229" bpy.ops.render.render(write_still=True) # render still\n",
3230" \n",
3231" \n",
3232" toggle_object_visibility(True)\n",
3233" toggle_mask_visibility(False)\n",
3234" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value\"].outputs[0].default_value = cam.matrix_world[0][3]\n",
3235" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.001\"].outputs[0].default_value = cam.matrix_world[1][3]\n",
3236" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.002\"].outputs[0].default_value = cam.matrix_world[2][3] \n",
3237" \n",
3238" \n",
3239" # Render mask\n",
3240" scene.render.filepath = os.path.join(fp, partition, f'r_{i}_mask')\n",
3241" if DEBUG:\n",
3242" break\n",
3243" else:\n",
3244" bpy.ops.render.render(write_still=True) # render still\n",
3245" \"\"\"\n",
3246" # NOTE(review): the string literal above is disabled code. It is the only place canonical_mat is applied to cam.matrix_world in this cell - confirm the camera orientation used below is intended.\n",
3247" toggle_object_visibility(False)  # hide_render = False on every non-camera object\n",
3248" toggle_mask_visibility(True)  # writes 1.01 into the World 'Math' node input\n",
3249"\n",
3250" #for ind in range(hotdog_points.shape[0]):\n",
3251" for ind in [100]:  # NOTE(review): only point 100 is visited; the loop over all points is commented out above\n",
3252" p = hotdog_points[ind, :]\n",
3253" if not np.all(np.isfinite(p)):\n",
3254" continue  # skip points containing NaN or inf\n",
3255" \n",
3256" point_x, point_y, point_z = p\n",
3257" \n",
3258" cam.matrix_world[0][3] = point_x  # only the translation column of the camera matrix is overwritten\n",
3259" cam.matrix_world[1][3] = point_y\n",
3260" cam.matrix_world[2][3] = point_z\n",
3261" \n",
3262" #scene.render.filepath = os.path.join(fp, partition, f'r_{ind}') \n",
3263" #bpy.ops.render.render(write_still=True) # render still\n",
3264" # NOTE(review): the render call above is commented out, so this loop moves the camera but writes no image.\n",
3265"\n",
3266"\n",
3267"\n",
3268" #frame_data = {\n",
3269" # #'file_path': scene.render.filepath,\n",
3270" # 'file_path': f'./{partition}/r_{i}',\n",
3271" # 'transform_matrix': listify_matrix(cam.matrix_world)\n",
3272" #}\n",
3273" #out_data['frames'].append(frame_data)\n",
3274"\n",
3275" #if transforms_file is None:\n",
3276" # b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3277" # b_empty.rotation_euler[2] += radians(2*stepsize)\n",
3278"\n",
3279" #if not DEBUG:\n",
3280" # with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:\n",
3281" # json.dump(out_data, out_file, indent=4)\n",
3282"\n",
3283"\n"
3284]
3285},
3286{
3287"cell_type": "code",
3288"execution_count": null,
3289"metadata": {
3290"id": "DvcjVwuz-Q1b"
3291},
3292"outputs": [],
3293"source": [
3294"# A simple script that uses Blender to render views of a single object by rotating the camera around it.\n",
3295"# Also produces depth map at the same time.\n",
3296"\n",
3297"import argparse, sys, os\n",
3298"import json\n",
3299"import bpy\n",
3300"import mathutils\n",
3301"import numpy as np\n",
3302" \n",
3303"\n",
3304"def listify_matrix(matrix):\n",
3305"    \"\"\"Convert `matrix` into a plain list of lists.\n",
3306"\n",
3307"    Each item produced by iterating `matrix` (one row) is copied into its own list.\"\"\"\n",
3308"    return [list(row) for row in matrix]\n",
3309"\n",
3310"def delistify_matrix(lst):\n",
3311"    \"\"\"Copy the entries `lst[r][c]` for r, c in 0..3 into a new `mathutils.Matrix()` and return it.\"\"\"\n",
3312"    mat = mathutils.Matrix()\n",
3313"    for r, c in ((r, c) for r in range(4) for c in range(4)):\n",
3314"        mat[r][c] = lst[r][c]\n",
3315"    return mat\n",
3316" \n",
3317"def parse_bin(s):\n",
3318"    \"\"\"Read `s[1:]` as base-2 digits and divide by 2**len(s[1:]); e.g. '.101' gives 0.625.\"\"\"\n",
3319"    digits = s[1:]  # everything after the first character (the '.' in the caller's strings)\n",
3320"    return int(digits, 2) / 2.0 ** len(digits)\n",
3321"def phi2(i):\n",
3322"    \"\"\"Base-2 radical inverse of `i`: its binary digits reversed and placed after the binary point (1 gives 0.5, 2 gives 0.25).\"\"\"\n",
3323"    return parse_bin('.' + format(i, 'b')[::-1])\n",
3324"def nice_uniform(N):\n",
3325"    \"\"\"Return two length-N lists `(u, v)` forming a 2D Hammersley-style point set.\n",
3326"\n",
3327"    `u[i] = i / N` and `v[i] = phi2(i)` for i = 0 .. N - 1.\n",
3328"    \"\"\"\n",
3329"    u = [i / float(N) for i in range(N)]\n",
3330"    v = [phi2(i) for i in range(N)]\n",
3331"\n",
3332"    return u, v\n",
3333"\n",
3334"def nice_uniform_spherical(N, hemisphere=True):\n",
3335"    \"\"\"Hammersley directions, see http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html\n",
3336"    Returns length-N arrays `(theta, phi)`: theta in [0, pi/2] if `hemisphere`, else in [0, pi]; phi = 2*pi*v.\"\"\"\n",
3337"    u, v = nice_uniform(N)\n",
3338"    # cos(theta) must be uniform: 1 - u for a hemisphere, 1 - 2u for a sphere. Doubling theta (the old code) is not uniform.\n",
3339"    theta = np.arccos(1.0 - (2.0 - int(hemisphere)) * np.array(u))\n",
3340"    phi = 2.0 * np.pi * np.array(v)\n",
3341"    return theta, phi\n",
3342" \n",
3343" \n",
3344" \n",
3345"hemisphere = True\n",
3346"camera_dist = np.sqrt(4.0**2 + 0.5**2)\n",
3347"def get_all_camera_matrices(N_cameras, camera_dist=camera_dist):\n",
3348"    \"\"\"Return a list of `N_cameras` 4x4 camera matrices at distance `camera_dist` from the origin.\n",
3349"\n",
3350"    Directions come from `nice_uniform_spherical(N_cameras, hemisphere)`, using the\n",
3351"    module-level `hemisphere` flag. For each matrix:\n",
3352"      * column 3 holds the position, `camera_dist` times the sampled direction;\n",
3353"      * column 2 (z axis) is the normalised direction from the origin to the camera;\n",
3354"      * column 1 (y axis) is world +z with its z-axis component removed, then normalised;\n",
3355"      * column 0 (x axis) is `cross(y, z)`.\n",
3356"    \"\"\"\n",
3357"    theta, phi = nice_uniform_spherical(N_cameras, hemisphere)\n",
3358"    # Spherical to Cartesian, one direction per row.\n",
3359"    directions = np.stack([np.sin(theta) * np.cos(phi),\n",
3360"                           np.sin(theta) * np.sin(phi),\n",
3361"                           np.cos(theta)], axis=-1)\n",
3362"\n",
3363"    cameras = []\n",
3364"    for direction in directions:\n",
3365"        pose = np.eye(4)\n",
3366"        pose[:3, 3] = direction * camera_dist\n",
3367"\n",
3368"        z_axis = direction / np.linalg.norm(direction)\n",
3369"        y_axis = np.array([0.0, 0.0, 1.0])\n",
3370"        y_axis -= z_axis * z_axis.dot(y_axis)\n",
3371"        y_axis[0] += 1e-10  # make sure that cameras pointing straight down/up have a defined y axis\n",
3372"        y_axis /= np.linalg.norm(y_axis)\n",
3373"        x_axis = np.cross(y_axis, z_axis)\n",
3374"        pose[:3, :3] = np.stack([x_axis, y_axis, z_axis], axis=-1)  # axes become columns 0..2\n",
3375"        cameras.append(pose)\n",
3376"    return cameras\n",
3377" \n",
3378"DEBUG = False\n",
3379"VIEWS = 512 # Only used if not specifying transforms_file\n",
3380"RESOLUTION = 128 #800\n",
3381"DEPTH_SCALE = 1.4\n",
3382"FORMAT = 'OPEN_EXR'\n",
3383"COLOR_DEPTH = 8 if FORMAT == 'PNG' else 32\n",
3384" \n",
3385" \n",
3386"# filename is /.../\u003cmodel\u003e.blend/.../\u003cscript\u003e.py\n",
3387"ind_f = __file__.find('.blend')\n",
3388"ind_i = __file__[:ind_f].rfind('/') + 1\n",
3389"model_name = __file__[ind_i:ind_f] + '_uniform'\n",
3390"\n",
3391"\n",
3392"if FORMAT == 'OPEN_EXR':\n",
3393" RESULTS_PATH = f'{model_name}_linear'\n",
3394"elif FORMAT == 'PNG':\n",
3395" RESULTS_PATH = model_name\n",
3396"else:\n",
3397" raise RuntimeError('format unknown')\n",
3398"\n",
3399"if RESOLUTION != 800:\n",
3400" RESULTS_PATH += f'_{RESOLUTION}x{RESOLUTION}'\n",
3401"\n",
3402"# Read from file\n",
3403"#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',\n",
3404"# f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']\n",
3405"#partitions = ['train', 'test']\n",
3406"transforms_files = [None]\n",
3407"partitions = ['occlusions']\n",
3408"\n",
3409"\n",
3410"fp = bpy.path.abspath(f\"//{RESULTS_PATH}\")\n",
3411"\n",
3412"# Create the output root plus one sub-directory per partition.\n",
3413"# exist_ok=True replaces the os.path.exists() check, which could race with another process creating the directory.\n",
3414"os.makedirs(fp, exist_ok=True)\n",
3415"for partition in partitions:\n",
3416"    os.makedirs(os.path.join(fp, partition), exist_ok=True)\n",
3417"\n",
3418"# Data to store in JSON file\n",
3419"out_data = {\n",
3420" 'camera_angle_x': bpy.data.objects['Camera'].data.angle_x,\n",
3421"}\n",
3422"\n",
3423"# Render Optimizations\n",
3424"bpy.context.scene.render.use_persistent_data = True\n",
3425"\n",
3426"\n",
3427"# Set up rendering of depth map.\n",
3428"bpy.context.scene.use_nodes = True\n",
3429"tree = bpy.context.scene.node_tree\n",
3430"links = tree.links\n",
3431"\n",
3432"# Add passes for additionally dumping albedo and normals.\n",
3433"bpy.context.scene.view_layers[\"RenderLayer\"].use_pass_normal = True\n",
3434"bpy.context.scene.render.image_settings.file_format = str(FORMAT)\n",
3435"bpy.context.scene.render.image_settings.color_depth = str(COLOR_DEPTH)\n",
3436"\n",
3437"# If using OpenEXR, set to linear color space\n",
3438"if FORMAT == 'OPEN_EXR':\n",
3439" bpy.data.scenes['Scene'].display_settings.display_device = 'None'\n",
3440" bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear'\n",
3441"else:\n",
3442" bpy.data.scenes['Scene'].display_settings.display_device = 'sRGB'\n",
3443" bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'sRGB' \n",
3444"\n",
3445"# Remove all tree nodes (loop over a snapshot: removing from the live collection mid-iteration can skip nodes).\n",
3446"for node in list(tree.nodes):\n",
3447"    tree.nodes.remove(node)\n",
3448"\n",
3449"if 'Custom Outputs' not in tree.nodes:\n",
3450" # Create input render layer node.\n",
3451" render_layers = tree.nodes.new('CompositorNodeRLayers')\n",
3452" render_layers.label = 'Custom Outputs'\n",
3453" render_layers.name = 'Custom Outputs'\n",
3454" \n",
3455" depth_file_output = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n",
3456" depth_file_output.label = 'Depth Output'\n",
3457" depth_file_output.name = 'Depth Output'\n",
3458" if FORMAT == 'OPEN_EXR':\n",
3459" add_one = tree.nodes.new('CompositorNodeMath')\n",
3460" add_one.operation = 'ADD'\n",
3461" add_one.inputs[1].default_value = 1.0\n",
3462" links.new(render_layers.outputs['Depth'], add_one.inputs[0])\n",
3463" \n",
3464" recip = tree.nodes.new('CompositorNodeMath')\n",
3465" recip.operation = 'DIVIDE'\n",
3466" recip.inputs[0].default_value = 1.0\n",
3467" links.new(add_one.outputs[0], recip.inputs[1])\n",
3468" \n",
3469" links.new(recip.outputs[0], depth_file_output.inputs[0])\n",
3470" \n",
3471" else:\n",
3472" # Remap as other types can not represent the full range of depth.\n",
3473" map = tree.nodes.new(type=\"CompositorNodeMapRange\")\n",
3474" # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map.\n",
3475" map.inputs['From Min'].default_value = 0\n",
3476" map.inputs['From Max'].default_value = 8\n",
3477" map.inputs['To Min'].default_value = 1\n",
3478" map.inputs['To Max'].default_value = 0\n",
3479" links.new(render_layers.outputs['Depth'], map.inputs[0])\n",
3480"\n",
3481" links.new(map.outputs[0], depth_file_output.inputs[0])\n",
3482" \n",
3483" normal_file_output = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n",
3484" normal_file_output.label = 'Normal Output'\n",
3485" normal_file_output.name = 'Normal Output'\n",
3486" normal_file_output.format.file_format = 'PNG'\n",
3487" \n",
3488" # Separate normals into channels, transform (x+1)/2 and combine\n",
3489" sep_rgba = tree.nodes.new('CompositorNodeSepRGBA')\n",
3490" links.new(render_layers.outputs['Normal'], sep_rgba.inputs[0])\n",
3491" \n",
3492" comb_rgba = tree.nodes.new('CompositorNodeCombRGBA')\n",
3493" add_ones = []\n",
3494" divide_by_twos = []\n",
3495" for i in range(3):\n",
3496" add_ones.append(tree.nodes.new('CompositorNodeMath'))\n",
3497" add_ones[i].operation = 'ADD'\n",
3498" add_ones[i].inputs[1].default_value = 1.0\n",
3499" links.new(sep_rgba.outputs[i], add_ones[i].inputs[0]) \n",
3500" \n",
3501" divide_by_twos.append(tree.nodes.new('CompositorNodeMath'))\n",
3502" divide_by_twos[i].operation = 'DIVIDE'\n",
3503" divide_by_twos[i].inputs[1].default_value = 2.0\n",
3504" links.new(add_ones[i].outputs[0], divide_by_twos[i].inputs[0]) \n",
3505" \n",
3506" links.new(divide_by_twos[i].outputs[0], comb_rgba.inputs[i])\n",
3507" \n",
3508" # Connect alpha\n",
3509" links.new(sep_rgba.outputs[3], comb_rgba.inputs[3])\n",
3510" \n",
3511" links.new(comb_rgba.outputs[0], normal_file_output.inputs[0])\n",
3512"\n",
3513"# Background\n",
3514"bpy.context.scene.render.dither_intensity = 0.0\n",
3515"bpy.context.scene.render.film_transparent = True\n",
3516"\n",
3517"# Create collection for objects not to render with background\n",
3518"\n",
3519" \n",
3520"objs = [ob for ob in bpy.context.scene.objects if ob.type == 'EMPTY' and 'Empty' in ob.name]  # was `in ('EMPTY')`: a substring test against a plain string, not tuple membership\n",
3521"bpy.ops.object.delete({\"selected_objects\": objs})\n",
3522"\n",
3523"def parent_obj_to_camera(b_camera):\n",
3524"    \"\"\"Create an object named 'Empty' at (0, 0, 0) and make it the parent of `b_camera`.\n",
3525"\n",
3526"    The empty is linked into the current scene's collection, made the active object of the current view layer, and returned.\n",
3527"    \"\"\"\n",
3528"    b_empty = bpy.data.objects.new(\"Empty\", None)\n",
3529"    b_empty.location = (0, 0, 0)\n",
3530"    b_camera.parent = b_empty  # setup parenting\n",
3531"    bpy.context.scene.collection.objects.link(b_empty)\n",
3532"    bpy.context.view_layer.objects.active = b_empty\n",
3533"    return b_empty\n",
3534"\n",
3535"\n",
3536"scene = bpy.context.scene\n",
3537"scene.render.resolution_x = RESOLUTION\n",
3538"scene.render.resolution_y = RESOLUTION\n",
3539"scene.render.resolution_percentage = 100\n",
3540"\n",
3541"cam = scene.objects['Camera']\n",
3542"#cam.location = (0, 4.0, 0.5)\n",
3543"\n",
3544"#cam_constraint = cam.constraints.new(type='TRACK_TO')\n",
3545"#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'\n",
3546"#cam_constraint.up_axis = 'UP_Y'\n",
3547"#b_empty = parent_obj_to_camera(cam)\n",
3548"#cam_constraint.target = b_empty\n",
3549"\n",
3550"\n",
3551"#scene.render.image_settings.file_format = 'PNG' # set output format to .png\n",
3552"scene.render.image_settings.file_format = FORMAT # set output format to FORMAT ('OPEN_EXR' in this script, not .png)\n",
3553"\n",
3554"from math import radians\n",
3555"\n",
3556"stepsize = 360.0 / VIEWS\n",
3557"rotation_mode = 'XYZ'\n",
3558"\n",
3559"\n",
3560"\n",
3561"\n",
3562"if not DEBUG:\n",
3563" for output_node in [tree.nodes['Depth Output'], tree.nodes['Normal Output']]:\n",
3564" output_node.base_path = ''\n",
3565"\n",
3566"for transforms_file, partition in zip(transforms_files, partitions):  # one pass per (transforms file, partition) pair\n",
3567" if transforms_file is not None:\n",
3568" with open(transforms_file) as in_file:\n",
3569" transforms_data = json.load(in_file)\n",
3570" \n",
3571" VIEWS = len(transforms_data['frames'])  # overrides the module-level VIEWS with the file's frame count\n",
3572" else:\n",
3573" cameras = get_all_camera_matrices(VIEWS)  # no transforms file: generate VIEWS camera matrices instead\n",
3574"\n",
3575" out_data['frames'] = []\n",
3576"\n",
3577" #for i in range(0, VIEWS, 20):\n",
3578" #if partition == 'train':\n",
3579" # continue\n",
3580" print(VIEWS)\n",
3581" for i in [350]:  # NOTE(review): only view 350 is visited; the loop over all views is commented out above\n",
3582" if transforms_file is None:\n",
3583" for a in range(4):  # copy the generated 4x4 matrix into cam.matrix_world entry by entry\n",
3584" for b in range(4):\n",
3585" cam.matrix_world[a][b] = cameras[i][a][b]\n",
3586" print(cameras[i], i)\n",
3587" #cam.location.x = cameras[i][0][3]\n",
3588" #cam.location.y = cameras[i][1][3]\n",
3589" #cam.location.z = cameras[i][2][3]\n",
3590" #if RANDOM_VIEWS:\n",
3591" # scene.render.filepath = fp + '/r_' + str(i)\n",
3592" # b_empty.rotation_euler = np.random.uniform(0, 2*np.pi, size=3)\n",
3593" #else:\n",
3594" # print(\"Rotation {}, {}\".format((stepsize * i), radians(stepsize * i)))\n",
3595" # scene.render.filepath = fp + '/r_{0:03d}'.format(int(i * stepsize))\n",
3596" else:\n",
3597" cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix']) \n",
3598" print(cam.matrix_world)\n",
3599" #peter.matrix_world = cam.matrix_world\n",
3600" if DEBUG:  # NOTE(review): b_empty, CIRCLE_FIXED_START and vertical_diff are not defined in this cell - this branch looks like it would raise NameError\n",
3601" i = np.random.randint(0,VIEWS)\n",
3602" b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3603" b_empty.rotation_euler[2] += radians(2*stepsize*i)\n",
3604" \n",
3605" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value\"].outputs[0].default_value = cam.matrix_world[0][3]  # camera x/y/z translation is copied into three World 'Value' nodes\n",
3606" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.001\"].outputs[0].default_value = cam.matrix_world[1][3]\n",
3607" bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.002\"].outputs[0].default_value = cam.matrix_world[2][3] \n",
3608" \n",
3609" print(\"Rotation {}, {}\".format((stepsize * i), radians(stepsize * i)))\n",
3610" scene.render.filepath = os.path.join(fp, partition, f'r_{i}')\n",
3611"\n",
3612" tree.nodes['Depth Output'].file_slots[0].path = scene.render.filepath + \"_disp_\"\n",
3613" tree.nodes['Normal Output'].file_slots[0].path = scene.render.filepath + \"_normal_\"\n",
3614"\n",
3615" break  # NOTE(review): appears to leave the view loop unconditionally, which would make the render call and frame_data below unreachable - confirm this is a deliberate dry run\n",
3616" if DEBUG:\n",
3617" break\n",
3618" else:\n",
3619" bpy.ops.render.render(write_still=True) # render still\n",
3620"\n",
3621" frame_data = {\n",
3622" #'file_path': scene.render.filepath,\n",
3623" 'file_path': f'./{partition}/r_{i}',\n",
3624" 'rotation': radians(stepsize),  # same value for every frame: stepsize is 360 / VIEWS as computed before this loop\n",
3625" 'transform_matrix': listify_matrix(cam.matrix_world)\n",
3626" }\n",
3627" out_data['frames'].append(frame_data)\n",
3628"\n",
3629" #if transforms_file is None:\n",
3630" # b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3631" # b_empty.rotation_euler[2] += radians(2*stepsize)\n",
3632"\n",
3633" if not DEBUG:\n",
3634" with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:\n",
3635" json.dump(out_data, out_file, indent=4)\n"
3636]
3637}
3638],
3639"metadata": {
3640"colab": {
3641"collapsed_sections": [],
3642"last_runtime": {
3643"build_target": "//googlex/gcam/buff/mipnerf360:notebook",
3644"kind": "private"
3645},
3646"name": "spherical_peter_problem_blender.ipynb",
3647"private_outputs": true,
3648"provenance": [
3649{
3650"file_id": "1Zys9YweuFO3a-1IbJNqMEYFcMgcI5-kH",
3651"timestamp": 1658639749060
3652}
3653]
3654},
3655"kernelspec": {
3656"display_name": "Python 3",
3657"name": "python3"
3658},
3659"language_info": {
3660"name": "python"
3661}
3662},
3663"nbformat": 4,
3664"nbformat_minor": 0
3665}
3666