1
{
2
  "cells": [
3
    {
4
      "cell_type": "markdown",
5
      "metadata": {
6
        "id": "JndnmDMp66FL"
7
      },
8
      "source": [
9
        "Copyright 2018 Google LLC.\n",
10
        "\n",
11
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
12
      ]
13
    },
14
    {
15
      "cell_type": "code",
16
      "execution_count": null,
17
      "metadata": {
18
        "id": "hMqWDc_m6rUC"
19
      },
20
      "outputs": [],
21
      "source": [
22
        "#@title Default title text\n",
23
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
24
        "# you may not use this file except in compliance with the License.\n",
25
        "# You may obtain a copy of the License at\n",
26
        "#\n",
27
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
28
        "#\n",
29
        "# Unless required by applicable law or agreed to in writing, software\n",
30
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
31
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
32
        "# See the License for the specific language governing permissions and\n",
33
        "# limitations under the License."
34
      ]
35
    },
36
    {
37
      "cell_type": "code",
38
      "execution_count": null,
39
      "metadata": {
40
        "id": "u3E9yyHoT8KO"
41
      },
42
      "outputs": [],
43
      "source": [
44
        "import jax.numpy as jnp\n",
45
        "import jax\n",
46
        "import numpy as np\n",
47
        "from PIL import Image\n",
48
        "import matplotlib.pyplot as plt\n",
49
        "import plotly.graph_objects as go\n",
50
        "import functools\n",
51
        "import jax.experimental.optimizers\n",
52
        "import time\n",
53
        "import flax\n",
54
        "import flax.linen as nn\n",
55
        "from typing import Sequence, Callable\n",
56
        "from IPython.display import clear_output\n",
57
        "import cv2\n",
58
        "import imageio\n",
59
        "import mediapy as media\n",
60
        "\n",
61
        "import os\n",
62
        "\n"
63
      ]
64
    },
65
    {
66
      "cell_type": "code",
67
      "execution_count": null,
68
      "metadata": {
69
        "id": "qK31kBAtIJ3n"
70
      },
71
      "outputs": [],
72
      "source": [
73
        "def linear_to_srgb(linear):\n",
74
        "  \"\"\"Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n",
75
        "  eps = jnp.finfo(jnp.float32).eps\n",
76
        "  srgb0 = 323 / 25 * linear\n",
77
        "  srgb1 = (211 * jnp.maximum(eps, linear)**(5 / 12) - 11) / 200\n",
78
        "  return jnp.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
79
        "\n",
80
        "def read_envmap(filename):\n",
81
        "  with open(filename, 'rb') as f:\n",
82
        "    return imageio.imread(f, 'exr')\n",
83
        "\n",
84
        "#envmap_linear = read_envmap(f'{DIRECTORY}}/ninomaru_teien_4k.exr')\n",
85
        "#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_4k.exr')\n",
86
        "#envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_4k.exr')\n",
87
        "#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_50x99.exr')\n",
88
        "envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_50x99.exr')\n",
89
        "\n",
90
        "envmap_linear = np.fliplr(envmap_linear)  # Blender flips this for some reason\n",
91
        "#envmap_linear = np.roll(envmap_linear, envmap_linear.shape[1]//2, axis=1)\n",
92
        "envmap_srgb = linear_to_srgb(envmap_linear)\n",
93
        "\n",
94
        "plt.imshow(envmap_srgb)"
95
      ]
96
    },
97
    {
98
      "cell_type": "code",
99
      "execution_count": null,
100
      "metadata": {
101
        "id": "hvG8F8ADKMHb"
102
      },
103
      "outputs": [],
104
      "source": [
105
        "envmap_H = 50\n",
106
        "envmap_W = envmap_H * 2 - 1\n",
107
        "#envmap_H, envmap_W = envmap_linear.shape[:2]\n",
108
        "\n",
109
        "envmap_gt = envmap_linear\n",
110
        "#envmap_gt = linear_to_srgb(cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA))\n",
111
        "#envmap_gt = cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA)\n",
112
        "plt.imshow(envmap_gt)\n",
113
        "plt.axis('off')"
114
      ]
115
    },
116
    {
117
      "cell_type": "code",
118
      "execution_count": null,
119
      "metadata": {
120
        "id": "O6RPcOJFowWR"
121
      },
122
      "outputs": [],
123
      "source": [
124
        "#with open(f'{DIRECTORY}/hotel_room_{envmap_H}x{envmap_W}.exr', 'wb') as f:\n",
125
        "#  imageio.imsave(f, envmap_gt, 'exr')\n"
126
      ]
127
    },
128
    {
129
      "cell_type": "code",
130
      "execution_count": null,
131
      "metadata": {
132
        "id": "hrCaejuLFcpI"
133
      },
134
      "outputs": [],
135
      "source": [
136
        "#envmap_H, envmap_W = envmap_linear.shape[:2]\n",
137
        "#omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1],\n",
138
        "#                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1])\n",
139
        "omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),\n",
140
        "                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1] +       jnp.pi / (2.0 * envmap_H))\n",
141
        "\n",
142
        "dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])\n",
143
        "\n",
144
        "omega_theta = omega_theta.flatten()\n",
145
        "omega_phi = omega_phi.flatten()\n",
146
        "\n",
147
        "omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)\n",
148
        "omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)\n",
149
        "omega_z = jnp.cos(omega_theta)\n",
150
        "omega_xyz = jnp.stack([omega_x,\n",
151
        "                       omega_y,\n",
152
        "                       omega_z], axis=-1)\n",
153
        "\n"
154
      ]
155
    },
156
    {
157
      "cell_type": "code",
158
      "execution_count": null,
159
      "metadata": {
160
        "id": "gJbvom3l620T"
161
      },
162
      "outputs": [],
163
      "source": [
164
        "def mse_to_psnr(mse):\n",
165
        "  \"\"\"Compute PSNR given an MSE (we assume the maximum pixel value is 1).\"\"\"\n",
166
        "  return -10. / jnp.log(10.) * jnp.log(mse)\n",
167
        "  \n",
168
        "def get_rays(H, W, focal, c2w, rand_ort=False, key=None):\n",
169
        "  \"\"\"\n",
170
        "  c2w: 4x4 matrix\n",
171
        "  output: two arrays of shape [H, W, 3]\n",
172
        "  \"\"\"\n",
173
        "  j, i = jnp.meshgrid(jnp.arange(W, dtype=jnp.float32),\n",
174
        "                      jnp.arange(H, dtype=jnp.float32))\n",
175
        "  \n",
176
        "  if rand_ort:\n",
177
        "    k1, k2 = random.split(key)\n",
178
        "\n",
179
        "    i += jax.random.uniform(k1, shape=(H, W)) - 0.5\n",
180
        "    j += jax.random.uniform(k2, shape=(H, W)) - 0.5\n",
181
        "      \n",
182
        "  dirs = jnp.stack([ (j.flatten()-0.5*W)/focal,\n",
183
        "                    -(i.flatten()-0.5*H)/focal,\n",
184
        "                    -jnp.ones((H*W,), dtype=jnp.float32)], -1)  # shape [HW, 3]\n",
185
        "  \n",
186
        "  rays_d = dirs @ c2w[:3, :3].T  # shape [HW, 3]\n",
187
        "  rays_o = c2w[:3,-1:].T.repeat(H*W, 0)\n",
188
        "  return rays_o.reshape(H, W, 3), rays_d.reshape(H, W, 3)\n"
189
      ]
190
    },
191
    {
192
      "cell_type": "code",
193
      "execution_count": null,
194
      "metadata": {
195
        "id": "bbHRuTQL7GZS"
196
      },
197
      "outputs": [],
198
      "source": [
199
        "def parse_bin(s):\n",
200
        "  return int(s[1:], 2) / 2.**(len(s) - 1)\n",
201
        "\n",
202
        "\n",
203
        "def phi2(i):\n",
204
        "  return parse_bin('.' + f'{i:b}'[::-1])\n",
205
        "\n",
206
        "def nice_uniform(N):\n",
207
        "  u = []\n",
208
        "  v = []\n",
209
        "  for i in range(N):\n",
210
        "    u.append(i / float(N))\n",
211
        "    v.append(phi2(i))\n",
212
        "    #pts.append((i/float(N), phi2(i)))\n",
213
        "\n",
214
        "  return u, v\n",
215
        "\n",
216
        "def nice_uniform_spherical(N, hemisphere=True):\n",
217
        "  \"\"\"implementation of http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html\"\"\"\n",
218
        "  u, v = nice_uniform(N)\n",
219
        "\n",
220
        "  theta = np.arccos(1.0 - np.array(u)) * (2.0 - int(hemisphere))\n",
221
        "  phi   = 2.0 * np.pi * np.array(v)\n",
222
        "\n",
223
        "  return theta, phi\n",
224
        "    \n",
225
        "hemisphere = True\n",
226
        "def get_all_camera_rays(N_cameras, camera_dist, H, W, focal):\n",
227
        "  theta, phi = nice_uniform_spherical(N_cameras, hemisphere)\n",
228
        "\n",
229
        "  camera_x_vec = np.sin(theta) * np.cos(phi)\n",
230
        "  camera_y_vec = np.sin(theta) * np.sin(phi)\n",
231
        "  camera_z_vec = np.cos(theta)\n",
232
        "\n",
233
        "  rays_o_vec = []\n",
234
        "  rays_d_vec = []\n",
235
        "  cameras = []\n",
236
        "  for i in range(N_cameras):\n",
237
        "    camera = np.eye(4)\n",
238
        "    camera[0, 3] = camera_x_vec[i] * camera_dist\n",
239
        "    camera[1, 3] = camera_y_vec[i] * camera_dist\n",
240
        "    camera[2, 3] = camera_z_vec[i] * camera_dist\n",
241
        "\n",
242
        "    zdir = np.array([camera_x_vec[i], camera_y_vec[i], camera_z_vec[i]])\n",
243
        "    zdir /= np.linalg.norm(zdir)\n",
244
        "\n",
245
        "    ydir = np.array([0.0, 0.0, 1.0])\n",
246
        "    ydir -= zdir * zdir.dot(ydir)\n",
247
        "    ydir[0] += 1e-10  # make sure that cameras pointing straight down/up have a defined ydir\n",
248
        "    ydir /= np.linalg.norm(ydir)\n",
249
        "\n",
250
        "    xdir = np.cross(ydir, zdir)\n",
251
        "\n",
252
        "\n",
253
        "    camera[:3, 0] = xdir\n",
254
        "    camera[:3, 1] = ydir\n",
255
        "    camera[:3, 2] = zdir\n",
256
        "\n",
257
        "    cameras.append(camera)\n",
258
        "\n",
259
        "    rays_o, rays_d = get_rays(H, W, focal, camera)\n",
260
        "\n",
261
        "    rays_o_vec.append(rays_o)\n",
262
        "    rays_d_vec.append(rays_d)\n",
263
        "\n",
264
        "  rays_o_vec = jnp.stack(rays_o_vec, 0)\n",
265
        "  rays_d_vec = jnp.stack(rays_d_vec, 0)\n",
266
        "\n",
267
        "  return rays_o_vec, rays_d_vec"
268
      ]
269
    },
270
    {
271
      "cell_type": "code",
272
      "execution_count": null,
273
      "metadata": {
274
        "id": "lsBMjzCg65dk"
275
      },
276
      "outputs": [],
277
      "source": [
278
        "def render_pixel(normal, lobe, envmap, mask):\n",
279
        "  masked_envmap = envmap * mask[:, :, None]\n",
280
        "  return (masked_envmap * lobe * jnp.sin(omega_theta).reshape(envmap_H, envmap_W, 1)).sum(0).sum(0) * dtheta_dphi / jnp.pi\n",
281
        "\n",
282
        "\n",
283
        "def render(envmap, mask, materials, normals, rays_d, alpha, shading='lambertian'):\n",
284
        "  \"\"\"\n",
285
        "  envmap:     shape [h, w, 3]\n",
286
        "  mask:       shape [h, w]\n",
287
        "  materials:  dictionary with entries of shape [N, 3]\n",
288
        "  normals:    shape [N, 3]\n",
289
        "  rays_d:     shape [N, 3]\n",
290
        "  alpha:      shape [N, 1]\n",
291
        "  \n",
292
        "  output: rendered colors, shape [N, 3]\n",
293
        "  \"\"\"\n",
294
        "  \n",
295
        "  assert shading in ['lambertian', 'phong', 'blinnphong']\n",
296
        "\n",
297
        "  if shading in ['lambertian', 'phong', 'blinnphong']:\n",
298
        "    # TODO: Feed in only pixels where alpha = 1\n",
299
        "    lobes = jnp.maximum(0.0, (omega_xyz.reshape(1, envmap_H, envmap_W, 3) * normals[:, None, None, :]).sum(-1, keepdims=True)) * materials['albedo'][:, None, None, :]  # [HW, envmap_H, envmap_W, 3]\n",
300
        "\n",
301
        "  if shading == 'blinnphong':\n",
302
        "    assert 'specular_albedo' in materials.keys()\n",
303
        "    specular_albedo = materials['specular_albedo'][:, None, None, :]\n",
304
        "    exponent = materials['specular_exponent'][:, None, None, :]\n",
305
        "\n",
306
        "    d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
307
        "    rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10)\n",
308
        "\n",
309
        "    halfvectors = omega_xyz.reshape(1, envmap_H, envmap_W, 3) + rays_d_norm[:, None, None, :]\n",
310
        "    halfvectors /= (jnp.linalg.norm(halfvectors, axis=-1, keepdims=True) + 1e-10)  # [N, envmap_H, envmap_W, 3]\n",
311
        "\n",
312
        "    lobes += jnp.maximum(0.0, (halfvectors * normals[:, None, None, :]).sum(-1, keepdims=True)) ** exponent * specular_albedo\n",
313
        "\n",
314
        "  if shading == 'phong':\n",
315
        "    assert 'specular_albedo' in materials.keys()\n",
316
        "    specular_albedo = materials['specular_albedo'][:, None, None, :]\n",
317
        "    exponent = materials['specular_exponent'][:, None, None, :]\n",
318
        "\n",
319
        "    d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
320
        "    rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10)  # [N, 3]\n",
321
        "\n",
322
        "    refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm   # [N, 3]\n",
323
        "    refdirs = refdirs[:, None, None, :]\n",
324
        "\n",
325
        "    # No need to normalize because ||n|| = 1 and ||d|| = 1, so ||2(n.d)n - d|| = 1.\n",
326
        "    print(\"Not normalizing here (because unnecessary, at least theoretically).\")\n",
327
        "    #refdirs /= (jnp.linalg.norm(refdirs, axis=-1, keepdims=True) + 1e-10)  # [N, HW, envmap_H, envmap_W, 3]\n",
328
        "\n",
329
        "    lobes += jnp.maximum(0.0, (refdirs * omega_xyz.reshape(1, envmap_H, envmap_W, 3)).sum(-1, keepdims=True)) ** exponent * specular_albedo\n",
330
        "     \n",
331
        "  colors = jax.vmap(render_pixel, in_axes=(0, 0, None, None))(normals, lobes, envmap, mask)     \n",
332
        "  \n",
333
        "  return colors * alpha"
334
      ]
335
    },
336
    {
337
      "cell_type": "code",
338
      "execution_count": null,
339
      "metadata": {
340
        "id": "dq8Nmn1i9Dnf"
341
      },
342
      "outputs": [],
343
      "source": [
344
        "def load_img(pth: str, is_16bit: bool=False) -\u003e np.ndarray:\n",
345
        "  \"\"\"Load an image and cast to float32.\"\"\"\n",
346
        "  with utils.open_file(pth, 'rb') as f:\n",
347
        "    if is_16bit:\n",
348
        "      bytes_ = np.asarray(bytearray(f.read()), dtype=np.uint8)  # Read bytes\n",
349
        "      image = np.array(\n",
350
        "          cv2.imdecode(bytes_, cv2.IMREAD_UNCHANGED), dtype=np.float32)\n",
351
        "    else:\n",
352
        "      image = np.array(Image.open(f), dtype=np.float32)\n",
353
        "  return image\n",
354
        "\n",
355
        "#disp = load_img(os.path.join(data_dir, 'test', 'r_0_disp.tiff'), is_16bit=True)[:, :, :1] / 255.0\n",
356
        "#plt.imshow(1/disp-1)"
357
      ]
358
    },
359
    {
360
      "cell_type": "code",
361
      "execution_count": null,
362
      "metadata": {
363
        "id": "CWfd1iC3iwmj"
364
      },
365
      "outputs": [],
366
      "source": []
367
    },
368
    {
369
      "cell_type": "code",
370
      "execution_count": null,
371
      "metadata": {
372
        "id": "vnLOEgyNOog9"
373
      },
374
      "outputs": [],
375
      "source": [
376
        "\"\"\"\n",
377
        "disp = load_img('{DIRECTORY}/r_4_disp_0029.tif', True) / 65535.0\n",
378
        "depth = 1.0 / disp[:, :, 0] - 1.0\n",
379
        "plt.imshow(depth)\n",
380
        "print(np.nanmin(depth), np.sqrt(4.0 ** 2 + 0.5 ** 2) - 1.0)\n",
381
        "\"\"\";"
382
      ]
383
    },
384
    {
385
      "cell_type": "code",
386
      "execution_count": null,
387
      "metadata": {
388
        "id": "cdnzd5bFi_4D"
389
      },
390
      "outputs": [],
391
      "source": [
392
        "Config = configs.Config()\n",
393
        "Config.dataset_loader = 'Blender'\n",
394
        "Config.near = 6\n",
395
        "Config.far = 2\n",
396
        "Config.factor = 1\n",
397
        "Config.disp_tiff = True\n",
398
        "\n",
399
        "# Force loading disparities and normals\n",
400
        "Config.compute_disp_metrics = True\n",
401
        "Config.compute_normal_metrics = True\n",
402
        "Config.semantic_dir = None\n",
403
        "import queue\n",
404
        "import jax\n",
405
        "import json\n",
406
        "import os\n",
407
        "\n",
408
        "LOCAL_COLMAP_DIR = '/tmp/colmap/'\n",
409
        "\n",
410
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_srgb_128x128'\n",
411
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_srgb_128x128'; Config.disp_tiff = False\n",
412
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_linear_128x128'\n",
413
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_uniform_linear_128x128'\n",
414
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_farther_occlusions_uniform_linear_128x128'\n",
415
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_farfield_occluder_uniform_linear_128x128'\n",
416
        "#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_new_uniform_linear_128x128'\n",
417
        "data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'\n",
418
        "#materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.38}\n",
419
        "\n",
420
        "#data_loader = Blender('test', data_dir, Config)\n",
421
        "data_loader = Blender('occlusions', data_dir, Config)\n",
422
        "#data_loader._next_fn = data_loader._next_test"
423
      ]
424
    },
425
    {
426
      "cell_type": "code",
427
      "execution_count": null,
428
      "metadata": {
429
        "id": "sWVplUIuxrdZ"
430
      },
431
      "outputs": [],
432
      "source": [
433
        "imgs_gt = []\n",
434
        "normals_gt = []\n",
435
        "disps_gt = []\n",
436
        "alpha_gt = []\n",
437
        "rays_o_ = []\n",
438
        "rays_d_ = []\n",
439
        "\n",
440
        "N_cameras = data_loader.size\n",
441
        "for i in range(N_cameras):\n",
442
        "  batch = next(data_loader)\n",
443
        "  imgs_gt.append(batch.rgb)\n",
444
        "  normals_gt.append(batch.normals)\n",
445
        "  alpha_gt.append(batch.alphas)\n",
446
        "  disps_gt.append(batch.disps)\n",
447
        "  rays_o_.append(batch.rays.origins)\n",
448
        "  rays_d_.append(batch.rays.directions)\n"
449
      ]
450
    },
451
    {
452
      "cell_type": "code",
453
      "execution_count": null,
454
      "metadata": {
455
        "id": "JtDBdudY1vgH"
456
      },
457
      "outputs": [],
458
      "source": [
459
        "imgs_gt = jnp.stack(imgs_gt, axis=0).reshape(N_cameras, -1, 3)\n",
460
        "normals_gt = jnp.stack(normals_gt, axis=0).reshape(N_cameras, -1, 3)\n",
461
        "alpha_gt = jnp.float32(jnp.stack(alpha_gt, axis=0).reshape(N_cameras, -1, 1) \u003e 0.99)\n",
462
        "disps_gt = jnp.stack(disps_gt, axis=0)[..., :1].reshape(N_cameras, -1, 1)\n",
463
        "rays_o_vec = jnp.stack(rays_o_, axis=0).reshape(N_cameras, -1, 3)\n",
464
        "rays_d_vec = jnp.stack(rays_d_, axis=0).reshape(N_cameras, -1, 3)\n",
465
        "\n",
466
        "t_surface_gt = 1.0 / disps_gt - 1.0\n",
467
        "materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.15}\n"
468
      ]
469
    },
470
    {
471
      "cell_type": "code",
472
      "execution_count": null,
473
      "metadata": {
474
        "id": "R2DvKhauS1xl"
475
      },
476
      "outputs": [],
477
      "source": [
478
        "#pts = rays_o_vec + rays_d_vec * t_surface_gt\n",
479
        "\n",
480
        "#normals_gt = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)\n",
481
        "\n"
482
      ]
483
    },
484
    {
485
      "cell_type": "code",
486
      "execution_count": null,
487
      "metadata": {
488
        "id": "-8jQSjT2RRBx"
489
      },
490
      "outputs": [],
491
      "source": [
492
        "H, W = data_loader.images[0].shape[:2]\n",
493
        "print(f\"There are {imgs_gt.shape[0]} images of size {H}x{W}\")"
494
      ]
495
    },
496
    {
497
      "cell_type": "code",
498
      "execution_count": null,
499
      "metadata": {
500
        "id": "C_CDjVlQzlGw"
501
      },
502
      "outputs": [],
503
      "source": [
504
        "ind = 10\n",
505
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
506
        "plt.figure()\n",
507
        "plt.imshow(normals_gt[ind].reshape(H, W, 3) * 0.5 + 0.5)\n",
508
        "plt.figure()\n",
509
        "plt.imshow(np.where(alpha_gt[ind] \u003c 0.99, np.nan, t_surface_gt[0]).reshape(H, W, 1))\n",
510
        "plt.figure()\n",
511
        "plt.imshow(alpha_gt[ind].reshape(H, W, 1))\n"
512
      ]
513
    },
514
    {
515
      "cell_type": "code",
516
      "execution_count": null,
517
      "metadata": {
518
        "id": "6Z-jwtNCms5O"
519
      },
520
      "outputs": [],
521
      "source": [
522
        "plt.imshow(imgs_gt[0].reshape(H, W, 3))\n",
523
        "plt.figure()\n",
524
        "plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)\n",
525
        "plt.figure()\n",
526
        "plt.imshow(t_surface_gt[0].reshape(H, W, 1))\n",
527
        "plt.figure()\n",
528
        "plt.imshow(alpha_gt[0].reshape(H, W, 1))\n"
529
      ]
530
    },
531
    {
532
      "cell_type": "code",
533
      "execution_count": null,
534
      "metadata": {
535
        "id": "cR9X0LxU9QTY"
536
      },
537
      "outputs": [],
538
      "source": [
539
        "rays_d_vec.shape"
540
      ]
541
    },
542
    {
543
      "cell_type": "code",
544
      "execution_count": null,
545
      "metadata": {
546
        "id": "hJaeLucx9LCU"
547
      },
548
      "outputs": [],
549
      "source": [
550
        "\n",
551
        "rays_d_r = rays_d_vec.reshape(-1, H, W, 3)\n",
552
        "\n",
553
        "occluder_relative_size = 0.1 #0.03  # Ratio of the unit sphere occluder by the occluder\n",
554
        "th = 1.0 - 2 * occluder_relative_size\n",
555
        "# Make masks\n",
556
        "mask_shape = 'circle'\n",
557
        "if mask_shape == 'circle' or mask_shape == 'two_circles' or mask_shape == 'three_circles':\n",
558
        "  sdfs_gt = th - jnp.sum(-omega_xyz[None, :, :] * rays_d_r[:, H//2, W//2, :][:, None, :], axis=-1)\n",
559
        "if mask_shape == 'two_circles' or mask_shape == 'three_circles':\n",
560
        "  camera_dirs = rays_d_r[:, H//2, W//2, :]\n",
561
        "  up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs\n",
562
        "  up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)\n",
563
        "  rot_dirs = jnp.cross(up_dirs, camera_dirs, axis=-1) # direction of rotation is camera direction cross up direction\n",
564
        "\n",
565
        "  rotation_angle = jnp.pi / 7\n",
566
        "  dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))\n",
567
        "  \n",
568
        "  sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1))  # Occluder is aligned with the camera\n",
569
        "if mask_shape == 'three_circles':\n",
570
        "  camera_dirs = rays_d_r[:, H//2, W//2, :]\n",
571
        "  up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs\n",
572
        "  up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)\n",
573
        "  rot_dirs = up_dirs # direction of rotation is up direction\n",
574
        "\n",
575
        "  rotation_angle = jnp.pi / 6\n",
576
        "  dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))\n",
577
        "  sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1))  # Occluder is aligned with the camera\n",
578
        "\n",
579
        "sdfs_gt = sdfs_gt.reshape(-1, envmap_H, envmap_W)\n",
580
        "masks_gt = jnp.float32(sdfs_gt \u003e 0.0)\n",
581
        "\n"
582
      ]
583
    },
584
    {
585
      "cell_type": "code",
586
      "execution_count": null,
587
      "metadata": {
588
        "id": "0nRjSoAD9pBp"
589
      },
590
      "outputs": [],
591
      "source": [
592
        "shading = 'lambertian'\n",
593
        "exposure = 1.0\n",
594
        "\n",
595
        "# Render dataset\n",
596
        "num_devices = jax.local_device_count()\n",
597
        "imgs_gt = []\n",
598
        "\n",
599
        "render_gt_partial = functools.partial(render, shading=shading)\n",
600
        "for i in range(N_cameras // num_devices):\n",
601
        "  i0 = i * num_devices\n",
602
        "  i1 = i0 + num_devices\n",
603
        "\n",
604
        "  imgs_gt_ = jax.pmap(render_gt_partial, in_axes=(None, 0, 0, 0, 0, 0))(  # here pmap is faster\n",
605
        "      envmap_gt * exposure,\n",
606
        "      masks_gt[i0:i1],\n",
607
        "      jax.tree_map(lambda x: x[i0:i1], materials_gt),\n",
608
        "      normals_gt[i0:i1],\n",
609
        "      rays_d_vec.reshape(-1, H*W, 3)[i0:i1],\n",
610
        "      alpha_gt[i0:i1],\n",
611
        "      )\n",
612
        "  imgs_gt.append(imgs_gt_)\n",
613
        "\n",
614
        "imgs_gt = linear_to_srgb(jnp.concatenate(imgs_gt, axis=0))\n",
615
        "\n",
616
        "print(jnp.nanmax(imgs_gt))\n"
617
      ]
618
    },
619
    {
620
      "cell_type": "code",
621
      "execution_count": null,
622
      "metadata": {
623
        "id": "NASFYPonhri0"
624
      },
625
      "outputs": [],
626
      "source": [
627
        "ind = 50\n",
628
        "plt.figure()\n",
629
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
630
      ]
631
    },
632
    {
633
      "cell_type": "code",
634
      "execution_count": null,
635
      "metadata": {
636
        "id": "brlbzchGscOQ"
637
      },
638
      "outputs": [],
639
      "source": [
640
        "def mfbrdf_map(viewdir, normal, albedo, roughness, eps=1e-15):\n",
641
        "  half_vecs = omega_xyz + viewdir[None, :]\n",
642
        "  half_vecs /= (jnp.linalg.norm(half_vecs, axis=-1, keepdims=True) + eps)\n",
643
        "\n",
644
        "  n_dot_v = jnp.abs(jnp.sum(viewdir * normal)) + 1e-5\n",
645
        "  n_dot_l = jnp.maximum(jnp.sum(omega_xyz * normal[None, :], axis=-1), 0.0)\n",
646
        "  n_dot_h = jnp.maximum(jnp.sum(normal[None, :] * half_vecs, axis=-1), 0.0)\n",
647
        "  l_dot_h = jnp.maximum(jnp.sum(omega_xyz * half_vecs, axis=-1), 0.0)\n",
648
        "\n",
649
        "  print(n_dot_v.shape, n_dot_l.shape)\n",
650
        "\n",
651
        "  F_0 = 0.04\n",
652
        "  a = roughness**2\n",
653
        "\n",
654
        "  D = a / (jnp.pi * ((a - 1.0) * n_dot_h ** 2 + 1.)**2)\n",
655
        "  F = F_0 + (1. - F_0) * jnp.power(1. - l_dot_h, 5)\n",
656
        "  #V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((n_dot_v * (1 - a) * n_dot_v + a))))\n",
657
        "  V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((-1. * n_dot_v * a + n_dot_v) * n_dot_v + a)))\n",
658
        "  brdf = D * F * V\n",
659
        "\n",
660
        "  print(brdf.shape)\n",
661
        "  brdf = brdf + (1. - F) * albedo / jnp.pi\n",
662
        "  #brdf = jnp.reshape(brdf, [mapres[0], mapres[1], 3])\n",
663
        "  return brdf\n",
664
        "\n",
665
        "brdf = mfbrdf_map(jnp.array([0.0, 0.0, 1.0]), jnp.array([0.0, 1.0, 1.0])/jnp.sqrt(2), 0.7, 0.7)\n",
666
        "plt.imshow(brdf.reshape(envmap_H, envmap_W))\n",
667
        "plt.colorbar()"
668
      ]
669
    },
670
    {
671
      "cell_type": "code",
672
      "execution_count": null,
673
      "metadata": {
674
        "id": "9_J8f4Ow0SXG"
675
      },
676
      "outputs": [],
677
      "source": [
678
        "num_devices = jax.local_device_count()\n",
679
        "\n",
680
        "shading = 'lambertian'\n",
681
        "exposure = 1.0\n",
682
        "\n",
683
        "render_partial = functools.partial(render, shading=shading)\n",
684
        "\n",
685
        "gt_list = []\n",
686
        "#gt_list = ['materials']\n",
687
        "#gt_list = ['masks', 'materials']\n",
688
        "#gt_list = ['sdfs', 'materials']\n",
689
        "#gt_list = ['envmap', 'materials']\n",
690
        "#gt_list = ['envmap', 'masks']\n",
691
        "#gt_list = ['masks']\n",
692
        "#gt_list = ['masks', 'materials', 'envmap']\n",
693
        "\n",
694
        "class MLP(nn.Module):\n",
695
        "  features: Sequence[int]\n",
696
        "\n",
697
        "  @nn.compact\n",
698
        "  def __call__(self, x, y=None):\n",
699
        "    if y is not None:\n",
700
        "      x = jnp.concatenate([x, y], axis=-1)\n",
701
        "    for feat in self.features[:-1]:\n",
702
        "      x = nn.relu(nn.Dense(feat)(x))\n",
703
        "    x = nn.Dense(self.features[-1])(x)\n",
704
        "    return x\n",
705
        "\n",
706
        "def get_pyramid_params(rng, num_pyramids, height, width, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias):\n",
707
        "  pyramid_params = []\n",
708
        "  for i in range(pyramid_num_scales):\n",
709
        "    gsh, gsw = [sz // pyramid_resize_scale ** i for sz in [height, width]]\n",
710
        "    key, rng = jax.random.split(rng)\n",
711
        "    features = jax.random.normal(key, (num_pyramids, gsh, gsw)) * global_std + global_bias\n",
712
        "    pyramid_params.append(features)\n",
713
        "  return pyramid_params\n",
714
        "\n",
715
        "\n",
716
        "def pyramids_to_imgs(pyramid_params, pyramid_mult, img_inds):\n",
717
        "  def pyramid_to_img(pyramid_params, pyramid_mult):\n",
718
        "    acc_val = pyramid_params[-1]\n",
719
        "    for i, curr_val in enumerate(pyramid_params[-2::-1], start=1):\n",
720
        "      # upsample\n",
721
        "      acc_val = jax.image.resize(acc_val * pyramid_mult, shape=curr_val.shape, method='linear')\n",
722
        "      \n",
723
        "      # accumulate\n",
724
        "      acc_val += curr_val\n",
725
        "    return acc_val\n",
726
        "\n",
727
        "  # Select all pyramid parameters at given indices\n",
728
        "  sub_pyramid_params = jax.tree_map(lambda t: t[img_inds], pyramid_params)\n",
729
        "  return jax.vmap(pyramid_to_img, in_axes=(0, None))(sub_pyramid_params, pyramid_mult)\n",
730
        "\n",
731
        "\n",
732
        "def grad_norm_spherical(dirs, grad):\n",
733
        "  \"\"\"\n",
734
        "  Compute gradient norm restricted to the sphere.\n",
735
        "  Assume dim 0 is elevation and 1 is azimuth.\n",
736
        "\n",
737
        "  dirs: (N, 3) array of directions on the sphere\n",
738
        "  grad: (N, 3) array of Cartesian gradients of points on dirs\n",
739
        "  \"\"\"\n",
740
        "\n",
741
        "  norm_spherical = grad - (dirs * grad).sum(axis=-1, keepdims=True) * dirs\n",
742
        "  return jnp.sqrt((norm_spherical ** 2).sum(-1) + 1e-5)\n",
743
        "\n",
744
        "\n",
745
        "\n",
746
        "rays_o_r = rays_o_vec.reshape(-1, H*W, 3)\n",
747
        "rays_d_r = rays_d_vec.reshape(-1, H*W, 3)\n",
748
        "\n",
749
        "\n",
750
        "\n",
751
        "append_identity = True\n",
752
        "def posenc(x, L_encoding):\n",
753
        "  if L_encoding \u003c= 0:\n",
754
        "    return x\n",
755
        "  else:\n",
756
        "    scales = 2**jnp.arange(L_encoding)\n",
757
        "    #shape = x.shape[:-1] + (-1,)\n",
758
        "    #scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape)\n",
759
        "\n",
760
        "    #four_feat = jnp.sin(\n",
761
        "    #    jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))\n",
762
        "    shape = x.shape[:-1] + (-1,)\n",
763
        "    scaled_x = x[..., None, :] * scales[:, None] # [..., L, D]\n",
764
        "\n",
765
        "    four_feat = jnp.sin(\n",
766
        "        jnp.stack([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1)) # [..., L, D, 2]\n",
767
        "\n",
768
        "    four_feat = jnp.reshape(four_feat / scales[:, None, None], shape)\n",
769
        "    print(\"Using Lipschitz posenc\")\n",
770
        "    if append_identity:\n",
771
        "      return jnp.concatenate([x] + [four_feat], axis=-1)\n",
772
        "    else:\n",
773
        "      return four_feat\n",
774
        "\n",
775
        "\n",
776
        "def params_to_sdf(params_sdf, img_inds):\n",
777
        "  if sdf_representation == 'mlp':\n",
778
        "    sdf, sdf_grad = jax.vmap(jax.value_and_grad(lambda x, y: sdf_mlp.apply(params_sdf, x, y)[0]))(\n",
779
        "        posenc(omega_xyz[None, :, :], L_encoding_sdf).repeat(img_inds.shape[0], 0).reshape(-1, mlp_input_features),\n",
780
        "        rays_o_vec[img_inds, 0, 0, :][:, None, :].repeat(envmap_H * envmap_W, 1).reshape(-1, 3)\n",
781
        "        )\n",
782
        "  elif sdf_representation == 'grid':\n",
783
        "    sdf = params_sdf[img_inds[:, None], ...]\n",
784
        "  else:\n",
785
        "    sdf = pyramids_to_imgs(params_sdf, pyramid_mult, img_inds)\n",
786
        "  return sdf\n",
787
        "\n",
788
        "\n",
789
        "def sdf_to_mask(x, width, curve='sigmoid'):\n",
790
        "  if curve == 'sigmoid':\n",
791
        "    return jax.nn.sigmoid(x * width)\n",
792
        "  elif curve == 'laplace_cdf':\n",
793
        "    return 0.5 + 0.5 * jnp.sign(x) * (1.0 - jnp.exp(-jnp.abs(x) * width))\n",
794
        "  else:\n",
795
        "    raise NotImplementedError('Only sigmoid and laplace_cdf for now.')\n",
796
        "  \n",
797
        "def params_to_materials(params_materials, pts):\n",
798
        "  mlp_res = material_mlp.apply(params_materials, posenc(pts, L_encoding_materials))\n",
799
        "  materials = {}\n",
800
        "  materials['albedo'] = jax.nn.sigmoid(mlp_res[..., 0:3])\n",
801
        "  if shading in ['phong', 'blinnphong']:\n",
802
        "    #materials['specular_albedo'] = 20.0 * jax.nn.sigmoid(mlp_res[..., 3:6])\n",
803
        "    if is_dielectric:\n",
804
        "      materials['specular_albedo'] = jax.nn.softplus(mlp_res[..., 3:4])\n",
805
        "      materials['specular_exponent'] = jax.nn.softplus(mlp_res[..., 4:5])\n",
806
        "    else:\n",
807
        "      materials['specular_albedo'] = jax.nn.softplus(mlp_res[..., 3:6])\n",
808
        "      materials['specular_exponent'] = jax.nn.softplus(mlp_res[..., 6:7])\n",
809
        "\n",
810
        "  return materials\n",
811
        "\n",
812
        "R = jnp.array([[ 0.0, 1.0, 0.0],\n",
813
        "               [-1.0, 0.0, 0.0],\n",
814
        "               [ 0.0, 0.0, 1.0]])\n",
815
        "\n",
816
        "@jax.jit\n",
817
        "def get_loss(params_envmap, params_sdf, params_materials, gt, spatial_inds, img_inds, i, rng):\n",
818
        "  rng, key = jax.random.split(rng)\n",
819
        "\n",
820
        "  if 'envmap' in gt_list:\n",
821
        "    envmap = envmap_gt * exposure\n",
822
        "  else:\n",
823
        "    envmap = params_to_envmap(params_envmap)\n",
824
        "\n",
825
        "  normals = normals_gt[img_inds[:, None], spatial_inds, ...]\n",
826
        "  t_surface = t_surface_gt[img_inds[:, None], spatial_inds, ...]\n",
827
        "  alpha = alpha_gt[img_inds[:, None], spatial_inds, ...]\n",
828
        "  rays_d = rays_d_r[img_inds[:, None], spatial_inds, ...]\n",
829
        "  rays_o = rays_o_r[img_inds[:, None], spatial_inds, ...]\n",
830
        "\n",
831
        "  if 'materials' in gt_list:\n",
832
        "    materials = jax.tree_map(lambda x: x[img_inds[:, None], spatial_inds[img_inds], ...], materials_gt)\n",
833
        "  else:\n",
834
        "    pts = rays_o + rays_d * t_surface\n",
835
        "    materials = params_to_materials(params_materials, pts)\n",
836
        "\n",
837
        "  sdf = params_to_sdf(params_sdf, img_inds)\n",
838
        "  sdf = sdf.reshape(img_inds.shape[0], envmap_H * envmap_W)\n",
839
        "  #mask_width = 200.0 * (i / num_iters) + 10.0\n",
840
        "  #masks = sdf_to_mask(sdf.reshape(-1, envmap_H, envmap_W), mask_width, 'sigmoid')\n",
841
        "\n",
842
        "  sdf_curve = 'sigmoid'\n",
843
        "\n",
844
        "  if 'masks' in gt_list or 'sdfs' in gt_list:\n",
845
        "    if 'masks' in gt_list:\n",
846
        "      masks = masks_gt[img_inds]\n",
847
        "    else:\n",
848
        "      masks = sdf_to_mask(sdfs_gt[img_inds], mask_width, sdf_curve)\n",
849
        "    #masks = masks * 0.0 + 1.0\n",
850
        "    #print(\"Setting masks to 1!!!!!!!!!!!!!!!!!\")\n",
851
        "    eikonal_loss = 0.0\n",
852
        "    length_loss = 0.0\n",
853
        "    mask_area_loss = 0.0\n",
854
        "  else:\n",
855
        "    #mask_width = 10.0 #200.0 * (i / num_iters) + 10.0\n",
856
        "    #mask_width = 20.0 * (i / num_iters) + 10.0\n",
857
        "    #mask_width = 0.1\n",
858
        "    #mask_width = 10.0 #150.0 * (i / num_iters) + 6.0\n",
859
        "    #mask_width = 0.1 * jnp.exp(jnp.log(100) * i / num_iters)\n",
860
        "    if straight_through_mode == 'hard':\n",
861
        "      masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
862
        "      masks = (masks + jax.lax.stop_gradient(jnp.float32(sdf \u003e 0.0) - masks))\n",
863
        "    elif straight_through_mode == 'soft':\n",
864
        "      print(\"I think this might actually be 'none' straight_through_mode instead of 'soft' because mask_width is 0.1\")\n",
865
        "      soft_masks = sdf_to_mask(sdf, 0.1, sdf_curve)\n",
866
        "      hard_masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
867
        "      # Define masks with value of `hard_masks` but gradients of `soft_masks`\n",
868
        "      masks = (soft_masks + jax.lax.stop_gradient(hard_masks - soft_masks))\n",
869
        "    elif straight_through_mode == 'none':\n",
870
        "      masks = sdf_to_mask(sdf, mask_width, sdf_curve)\n",
871
        "\n",
872
        "\n",
873
        "\n",
874
        "    if sdf_representation == 'mlp':\n",
875
        "      # TODO: Only compute grad w.r.t. x, not w.r.t. other posenc components\n",
876
        "      sdf_grad = sdf_grad[..., :3].reshape(img_inds.shape[0], envmap_H * envmap_W, 3)\n",
877
        "      sdf_grad_norm = grad_norm_spherical(omega_xyz, sdf_grad.reshape(img_inds.shape[0], envmap_H * envmap_W, 3))\n",
878
        "      eikonal_loss = (jnp.sin(omega_theta) * (sdf_grad_norm - 1) ** 2).sum() * dtheta_dphi\n",
879
        "    else:\n",
880
        "      eikonal_loss = 0.0\n",
881
        "\n",
882
        "    # Compute entropy assuming mask = sigmoid(sdf)\n",
883
        "    entropy_loss = jax.nn.softplus(-sdf) + sdf * (1.0 - jax.nn.sigmoid(sdf))\n",
884
        "    mask_area_loss = 1.0 - masks  # Try to make the occluder as small as possible\n",
885
        "    #delta_sdf = jax.vmap(jax.grad(lambda x: sdf_to_mask(x, mask_width, sdf_curve)))(sdf.flatten()).reshape(img_inds.shape[0], envmap_H * envmap_W)\n",
886
        "\n",
887
        "    #length_loss = (jnp.sin(omega_theta) * sdf_grad_norm * delta_sdf).sum() * dtheta_dphi\n",
888
        "\n",
889
        "  res = jax.vmap(render_partial, in_axes=(None, 0, 0, 0, 0, 0))(\n",
890
        "      envmap,\n",
891
        "      masks.reshape(img_inds.shape[0], envmap_H, envmap_W),\n",
892
        "      materials,\n",
893
        "      normals,\n",
894
        "      rays_d,\n",
895
        "      alpha,\n",
896
        "      )\n",
897
        "\n",
898
        "  #diff = gt[img_inds[:, None], spatial_inds[img_inds], :] - res\n",
899
        "  diff = gt[img_inds[:, None], spatial_inds, :] - linear_to_srgb(res)\n",
900
        "  #loss_per_element = (diff ** 2).sum(-1)\n",
901
        "  #loss_per_element = jnp.abs(diff).sum(-1)\n",
902
        "  if False:\n",
903
        "    p = 2 - 1.5 * i / num_iters\n",
904
        "    data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)\n",
905
        "    print(\"Using graduated nonconvexity in the loss\")\n",
906
        "  else:\n",
907
        "    data_loss = (jnp.abs(diff) ** 2).sum()\n",
908
        "    print(\"Using L2 loss\")\n",
909
        "\n",
910
        "  data_loss = data_loss / img_inds.shape[0] / spatial_inds.shape[-1]\n",
911
        "\n",
912
        "  loss = data_loss\n",
913
        "  loss += 0.1 * eikonal_loss / img_inds.shape[0]\n",
914
        "  loss += 1e-5 * (mask_area_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi\n",
915
        "  #loss += 1e-1 * ((jnp.abs(params_envmap) ** 2) * ell.flatten()[:, None, None]).mean()\n",
916
        "  #loss += 0.01 * (entropy_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi\n",
917
        "  #loss += 0.01 * length_loss / img_inds.shape[0]\n",
918
        "\n",
919
        "  # Environment map TV loss\n",
920
        "  loss += 1e-7 * (((envmap[:, 1:] - envmap[:, :-1]) ** 2).sum() + ((envmap[1:, :] - envmap[:-1, :]) ** 2).sum())  * dtheta_dphi / 4.0 / jnp.pi\n",
921
        "\n",
922
        "  return loss, (data_loss, eikonal_loss)\n",
923
        "\n",
924
        "def safe_exp(x):\n",
925
        "  return jnp.exp(jnp.minimum(x, 80.0))\n",
926
        "\n",
927
        "def tonemap_and_clip(x):\n",
928
        "  return np.clip(linear_to_srgb(x), 0.0, 1.0)\n",
929
        "\n",
930
        "def params_to_envmap(params_envmap):\n",
931
        "  #envmap = jax.nn.sigmoid(params_envmap)\n",
932
        "  #envmap = jax.nn.softplus(params_envmap)\n",
933
        "  if envmap_representation == 'SH':\n",
934
        "    #params_envmap = jnp.where(ell.flatten()[:, None, None] \u003c 5, params_envmap, 0.0)\n",
935
        "    envmap = jax.nn.softplus(jax.vmap(isht, in_axes=-1, out_axes=-1)(params_envmap))\n",
936
        "  else:\n",
937
        "    envmap = safe_exp(params_envmap)\n",
938
        "  return envmap\n",
939
        "\n",
940
        "\n",
941
        "@jax.jit\n",
942
        "def update_params(i, rng, state_envmap, state_sdf, state_materials, gt, spatial_inds, img_inds):\n",
943
        "  params_envmap = get_params_envmap(state_envmap)\n",
944
        "  params_sdf = get_params_sdf(state_sdf)\n",
945
        "  params_materials = get_params_materials(state_materials)\n",
946
        "\n",
947
        "  (loss, (data_loss, eikonal_loss)), g = jax.value_and_grad(get_loss, argnums=(0, 1, 2), has_aux=True)(params_envmap, params_sdf, params_materials,\n",
948
        "                                                                                                       gt, spatial_inds, img_inds, i, rng)\n",
949
        "\n",
950
        "  grad_envmap = jax.lax.pmean(g[0], axis_name='batch')\n",
951
        "  grad_sdf = jax.lax.pmean(g[1], axis_name='batch')\n",
952
        "  grad_materials = jax.lax.pmean(g[2], axis_name='batch')\n",
953
        "  eikonal_loss = jax.lax.pmean(eikonal_loss, axis_name='batch')\n",
954
        "  data_loss = jax.lax.pmean(data_loss, axis_name='batch')\n",
955
        "  loss = jax.lax.pmean(loss, axis_name='batch')\n",
956
        "\n",
957
        "  return update_envmap(i, grad_envmap, state_envmap), update_sdf(i, grad_sdf, state_sdf), update_materials(i, grad_materials, state_materials), loss, data_loss, eikonal_loss\n",
958
        "\n",
959
        "\n",
960
        "#slow_optimization_mode = ('masks' not in gt_list and 'sdfs' not in gt_list) or 'materials' not in gt_list\n",
961
        "#if slow_optimization_mode:\n",
962
        "#  print(\"Slow optimization\")\n",
963
        "#else:\n",
964
        "#  print(\"Fast optimization\")\n",
965
        "\n",
966
        "num_iters = 150000 #50000 if slow_optimization_mode else 10000\n",
967
        "#num_iters = 50000\n",
968
        "straight_through_mode = 'soft'\n",
969
        "assert straight_through_mode in ['none', 'hard', 'soft']\n",
970
        "\n",
971
        "# TODO:\n",
972
        "# 1. It looks like using straight-through on the masks (with constant width 10) makes them be a little off,\n",
973
        "#    but improves the envmap (making it a little noisier because of the bad masks). Why?\n",
974
        "# 2. \n",
975
        "\n",
976
        "envmap_representation = 'direct'\n",
977
        "if envmap_representation == 'SH':\n",
978
        "  #params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3)) - 0.5) * 0.01\n",
979
        "  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3, 2)) - 0.5) / (1 + ell.flatten()[:, None, None, None]) * 0.01\n",
980
        "  params_envmap = params_envmap[..., 0] + 1j * params_envmap[..., 1]\n",
981
        "  init_lr_envmap = 0.0003 #if slow_optimization_mode else 0.001\n",
982
        "\n",
983
        "elif envmap_representation == 'direct':\n",
984
        "  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_W, 3)) - 0.5) * 0.1\n",
985
        "  #print(\"TODO: Get rid of this annoying -4.0\")\n",
986
        "\n",
987
        "  init_lr_envmap = 0.003 #if slow_optimization_mode else 0.03\n",
988
        "else:\n",
989
        "  raise ValueError('')\n",
990
        "  \n",
991
        "init_envmap, update_envmap, get_params_envmap = jax.experimental.optimizers.adam(init_lr_envmap)\n",
992
        "state_envmap = init_envmap(params_envmap)\n",
993
        "\n",
994
        "\n",
995
        "sdf_representation = 'pyramid'  # 'pyramid', 'mlp', 'pyramid'\n",
996
        "\n",
997
        "if sdf_representation == 'mlp':\n",
998
        "  L_encoding_sdf = 0 # 4\n",
999
        "  mlp_input_features = 3 + 6 * L_encoding_sdf\n",
1000
        "\n",
1001
        "  sdf_mlp = MLP([128]*4 + [1])\n",
1002
        "\n",
1003
        "  params_sdf = sdf_mlp.init(jax.random.PRNGKey(0),\n",
1004
        "                            np.zeros([1, mlp_input_features]),\n",
1005
        "                            np.zeros([1, 3]))\n",
1006
        "  init_lr_sdf = 0.0001\n",
1007
        "elif sdf_representation == 'pyramid':\n",
1008
        "  pyramid_num_scales = 5\n",
1009
        "  pyramid_resize_scale = 2\n",
1010
        "  pyramid_mult = 2.0\n",
1011
        "  global_std = 1.0 #0.1\n",
1012
        "  global_bias = 4.0\n",
1013
        "\n",
1014
        "  rng = jax.random.PRNGKey(0)\n",
1015
        "  params_sdf = get_pyramid_params(rng, N_cameras, envmap_H, envmap_W, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias)\n",
1016
        "  init_lr_sdf = 0.1 #* 100\n",
1017
        "  mask_width = 0.1\n",
1018
        "\n",
1019
        "elif sdf_representation == 'grid':\n",
1020
        "  params_sdf = jax.random.normal(jax.random.PRNGKey(111), shape=(N_cameras, envmap_H, envmap_W))\n",
1021
        "  init_lr_sdf = 0.01\n",
1022
        "else:\n",
1023
        "  raise ValueError('')\n",
1024
        "\n",
1025
        "#for _ in range(20):\n",
1026
        "#  print(\"TODO: Use xmanager to optimize: learning rates, biases, global std for params_sdf in pyramid mode, number of iterations (longer!), etc.\")\n",
1027
        "#  print(\"TODO: Find out what happens if for masks we use the ground truth SDFs passed through a 0.1 sigmoid, instead of the GT masks. This is the best case scenario when using such a soft sigmoid!\")\n",
1028
        "#  print(\"TODO: Replace initialization as soft ~0.5ish masks and dark envmap with good init. of envmap and ~1 masks (no occluders). Currently things just go there anyway...\")\n",
1029
        "\n",
1030
        "init_sdf, update_sdf, get_params_sdf = jax.experimental.optimizers.adam(init_lr_sdf)\n",
1031
        "state_sdf = init_sdf(params_sdf)\n",
1032
        "\n",
1033
        "# Initialize material MLP\n",
1034
        "L_encoding_materials = 4\n",
1035
        "mlp_input_features = 3 + 6 * L_encoding_materials\n",
1036
        "\n",
1037
        "is_dielectric = True\n",
1038
        "num_components = 3  # 3 for diffuse\n",
1039
        "if shading in ['phong', 'blinnphong']:\n",
1040
        "  if is_dielectric:\n",
1041
        "    num_components += 2  # 1 for specular albedo, 1 for exponent\n",
1042
        "  else:\n",
1043
        "    num_components += 4  # 3 for specular albedo, 1 for exponent\n",
1044
        "material_mlp = MLP([128]*4 + [num_components])\n",
1045
        "\n",
1046
        "params_materials = material_mlp.init(jax.random.PRNGKey(0),\n",
1047
        "                          np.zeros([1, mlp_input_features]))\n",
1048
        "init_lr_materials = 0.003\n",
1049
        "init_materials, update_materials, get_params_materials = jax.experimental.optimizers.adam(init_lr_materials)\n",
1050
        "state_materials = init_materials(params_materials)\n",
1051
        "\n",
1052
        "\n",
1053
        "np_rng = np.random.default_rng(12345)\n",
1054
        "jax_rng = jax.random.PRNGKey(3948)\n",
1055
        "\n",
1056
        "spatial_batch_size = 64\n",
1057
        "\n",
1058
        "img_batch_size = N_cameras #1024\n",
1059
        "#img_batch_size = 64\n",
1060
        "losses = []\n",
1061
        "data_losses = []\n",
1062
        "eikonal_losses = []\n",
1063
        "envmap_psnrs = []\n",
1064
        "tonemapped_envmap_psnrs = []\n",
1065
        "envmaps = []\n",
1066
        "\n",
1067
        "t = 0.0\n",
1068
        "training_progress_bar = ProgressBar()\n",
1069
        "training_progress_bar.Publish()\n",
1070
        "\n",
1071
        "replicated_state_envmap = flax.jax_utils.replicate(state_envmap)\n",
1072
        "replicated_state_sdf = flax.jax_utils.replicate(state_sdf)\n",
1073
        "replicated_state_materials = flax.jax_utils.replicate(state_materials)\n",
1074
        "replicated_imgs_gt = flax.jax_utils.replicate(imgs_gt)\n",
1075
        "\n",
1076
        "\n",
1077
        "for iteration in range(num_iters):\n",
1078
        "  #t0 = time.time()\n",
1079
        "  # Generate B1 image indices\n",
1080
        "  img_inds = np_rng.choice(imgs_gt.shape[0], size=img_batch_size, replace=False) #* 0\n",
1081
        "  # Now generate B2 pixel indices for each image. The total batch size is B1 * B2.\n",
1082
        "\n",
1083
        "  keys = jax.random.split(jax_rng, num=img_batch_size+1)\n",
1084
        "  jax_rng, keys = keys[0], keys[1:]\n",
1085
        "  spatial_inds = jax.vmap(jax.random.choice, in_axes=(0, None, None, None, 0))(keys, H*W, (spatial_batch_size,), False, alpha_gt[img_inds, :, 0])\n",
1086
        "\n",
1087
        "  assert jnp.all(alpha_gt[img_inds[:, None], spatial_inds, :] \u003e 0.99)\n",
1088
        "\n",
1089
        "  replicated_state_envmap, replicated_state_sdf, replicated_state_materials, loss, data_loss, eikonal_loss = jax.pmap(update_params, in_axes=(None, None, 0, 0, 0, 0, 0, 0), axis_name='batch')(\n",
1090
        "      iteration,\n",
1091
        "      jax_rng,\n",
1092
        "      replicated_state_envmap,\n",
1093
        "      replicated_state_sdf,\n",
1094
        "      replicated_state_materials,\n",
1095
        "      replicated_imgs_gt,\n",
1096
        "      spatial_inds.reshape(num_devices, -1, spatial_batch_size),\n",
1097
        "      img_inds.reshape(num_devices, -1)\n",
1098
        "      )\n",
1099
        "\n",
1100
        "  if iteration % 100 == 0 or iteration == num_iters - 1:\n",
1101
        "    envmap = params_to_envmap(get_params_envmap(replicated_state_envmap)[0])\n",
1102
        "    envmaps.append(envmap)\n",
1103
        "    mse = (jnp.sin(omega_theta)[:, None] * (exposure * envmap_gt - envmap).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0\n",
1104
        "    envmap_psnrs.append(mse_to_psnr(mse))\n",
1105
        "    mse = (jnp.sin(omega_theta)[:, None] * (tonemap_and_clip(exposure * envmap_gt) - tonemap_and_clip(envmap)).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0\n",
1106
        "    tonemapped_envmap_psnrs.append(mse_to_psnr(mse))\n",
1107
        "\n",
1108
        "\n",
1109
        "  losses.append(loss[0])\n",
1110
        "  data_losses.append(data_loss[0])\n",
1111
        "  eikonal_losses.append(eikonal_loss[0])\n",
1112
        "\n",
1113
        "  if iteration in [500] or iteration % 1000 == 0 and iteration \u003e 0:\n",
1114
        "    clear_output(wait=True)\n",
1115
        "    training_progress_bar.Publish()\n",
1116
        "    training_progress_bar.SetProgress(100.0 * (iteration + 1) / num_iters)\n",
1117
        "\n",
1118
        "    plt.figure()\n",
1119
        "    plt.semilogy(data_losses)\n",
1120
        "    plt.semilogy(eikonal_losses)\n",
1121
        "\n",
1122
        "    # Plot envmaps\n",
1123
        "    plt.figure(figsize=[16, 8])\n",
1124
        "    plt.subplot(221)\n",
1125
        "    envmap = params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))\n",
1126
        "    plt.imshow(envmap)\n",
1127
        "    plt.axis('off')\n",
1128
        "    plt.subplot(222)\n",
1129
        "    plt.imshow(exposure * envmap_gt)\n",
1130
        "    plt.axis('off')\n",
1131
        "    plt.subplot(223)\n",
1132
        "    plt.imshow(linear_to_srgb(params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))\n",
1133
        "    plt.axis('off')\n",
1134
        "    plt.subplot(224)\n",
1135
        "    plt.imshow(linear_to_srgb(exposure * envmap_gt))\n",
1136
        "    plt.axis('off')\n",
1137
        "\n",
1138
        "    # Plot PSNRs\n",
1139
        "    plt.figure(figsize=[8, 8])\n",
1140
        "    plt.subplot(121)\n",
1141
        "    plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), envmap_psnrs)\n",
1142
        "    plt.subplot(122)\n",
1143
        "    plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), tonemapped_envmap_psnrs)\n",
1144
        "\n",
1145
        "    # Plot masks\n",
1146
        "    num_rows = 6\n",
1147
        "    cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]\n",
1148
        "    sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))\n",
1149
        "    sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))\n",
1150
        "\n",
1151
        "    plt.figure(figsize=[15, 12])\n",
1152
        "    for row, i in enumerate(cameras_to_plot):\n",
1153
        "      sdf = sdfs[row].reshape(envmap_H, envmap_W)\n",
1154
        "      for col, img_to_plot in enumerate([sdf, sdf_to_mask(sdf, mask_width), sdf \u003e 0]):\n",
1155
        "        plt.subplot(num_rows, 3, row * 3 + col + 1)\n",
1156
        "        if img_to_plot.shape[-1] != 3:\n",
1157
        "          if col == 0:\n",
1158
        "            plt.imshow(img_to_plot, cmap='gray')\n",
1159
        "          else:\n",
1160
        "            plt.imshow(img_to_plot, cmap='gray', vmin=0.0, vmax=1.0)\n",
1161
        "        else:\n",
1162
        "          plt.imshow(img_to_plot)\n",
1163
        "        plt.axis('off')\n",
1164
        "        plt.title(f'{i}')\n",
1165
        "\n",
1166
        "\n",
1167
        "    # Plot materials\n",
1168
        "    ind = 10\n",
1169
        "    pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]\n",
1170
        "    materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)\n",
1171
        "    for k in materials.keys():\n",
1172
        "      plt.figure()\n",
1173
        "      plt.imshow(materials[k].reshape(H, W, -1))\n",
1174
        "      plt.axis('off')\n",
1175
        "    \n",
1176
        "    # Plot rendered image and GT image\n",
1177
        "    sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1178
        "    mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1179
        "    if 'envmap' in gt_list:\n",
1180
        "      envmap = exposure * envmap_gt\n",
1181
        "    if 'mask' in gt_list:\n",
1182
        "      mask = masks_gt[ind]\n",
1183
        "    if 'materials' in gt_list:\n",
1184
        "      materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1185
        "    res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1186
        "    plt.figure()\n",
1187
        "    plt.subplot(121)\n",
1188
        "    plt.imshow(res)\n",
1189
        "    plt.axis('off')\n",
1190
        "    plt.subplot(122)\n",
1191
        "    plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1192
        "    plt.axis('off')\n",
1193
        "    plt.show()\n",
1194
        "\n"
1195
      ]
1196
    },
1197
    {
1198
      "cell_type": "code",
1199
      "execution_count": null,
1200
      "metadata": {
1201
        "id": "WyAaBUBblCpR"
1202
      },
1203
      "outputs": [],
1204
      "source": [
1205
        "plt.imshow(jnp.log(1e-5 + params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))\n",
1206
        "plt.figure()\n",
1207
        "plt.imshow(jnp.log(1e-5 + exposure * envmap_gt))\n"
1208
      ]
1209
    },
1210
    {
1211
      "cell_type": "code",
1212
      "execution_count": null,
1213
      "metadata": {
1214
        "id": "v81Xb-E5T_oR"
1215
      },
1216
      "outputs": [],
1217
      "source": [
1218
        "ind = 3\n",
1219
        "pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]\n",
1220
        "materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)\n",
1221
        "\n",
1222
        "sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1223
        "mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1224
        "\n",
1225
        "res = linear_to_srgb(render_partial(envmap_gt, mask[0], {'albedo': materials['albedo'] * 0.0 + 0.15}, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1226
        "plt.imshow(res)\n",
1227
        "plt.figure()\n",
1228
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
1229
      ]
1230
    },
1231
    {
1232
      "cell_type": "code",
1233
      "execution_count": null,
1234
      "metadata": {
1235
        "id": "LDMPMheWVG8G"
1236
      },
1237
      "outputs": [],
1238
      "source": [
1239
        "ind = 0\n",
1240
        "\n",
1241
        "sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1242
        "mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1243
        "\n",
1244
        "#envmap_top = np.zeros_like(envmap_gt)\n",
1245
        "#envmap_top[0, :, :] = 1000\n",
1246
        "#envmap_top[25, 49, :] = 10000\n",
1247
        "res = render_partial(envmap_gt, mask[0]*0.0+1.0,\n",
1248
        "                                    {'albedo': materials['albedo'] * 0.0 + 0.15},\n",
1249
        "                                    normals_gt[ind], -rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1250
        "plt.imshow(linear_to_srgb(res), interpolation='nearest')\n",
1251
        "plt.figure()\n",
1252
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))"
1253
      ]
1254
    },
1255
    {
1256
      "cell_type": "code",
1257
      "execution_count": null,
1258
      "metadata": {
1259
        "id": "RTlmYw1CtsWn"
1260
      },
1261
      "outputs": [],
1262
      "source": [
1263
        "omega_phi.reshape(envmap_H, envmap_W)[0, 49]"
1264
      ]
1265
    },
1266
    {
1267
      "cell_type": "code",
1268
      "execution_count": null,
1269
      "metadata": {
1270
        "id": "8dWB1SHRroGu"
1271
      },
1272
      "outputs": [],
1273
      "source": []
1274
    },
1275
    {
1276
      "cell_type": "code",
1277
      "execution_count": null,
1278
      "metadata": {
1279
        "id": "-CvSkOxhroXv"
1280
      },
1281
      "outputs": [],
1282
      "source": [
1283
        "plt.imshow(omega_x.reshape(envmap_H, envmap_W))"
1284
      ]
1285
    },
1286
    {
1287
      "cell_type": "code",
1288
      "execution_count": null,
1289
      "metadata": {
1290
        "id": "Nzezbd06rodC"
1291
      },
1292
      "outputs": [],
1293
      "source": []
1294
    },
1295
    {
1296
      "cell_type": "code",
1297
      "execution_count": null,
1298
      "metadata": {
1299
        "id": "CG4EL5ySpm2U"
1300
      },
1301
      "outputs": [],
1302
      "source": [
1303
        "plt.plot(res[64])\n",
1304
        "print(jnp.nanmin(res) / jnp.nanmax(res))\n",
1305
        "print(jnp.nanmax(res) * 0.27823895, res[18, 64, 0])"
1306
      ]
1307
    },
1308
    {
1309
      "cell_type": "code",
1310
      "execution_count": null,
1311
      "metadata": {
1312
        "id": "H08mIu3Tj9B6"
1313
      },
1314
      "outputs": [],
1315
      "source": [
1316
        "print(alpha_gt[0].reshape(H, W)[18, :].sum())\n",
1317
        "print(alpha_gt[0].reshape(H, W)[18, 64])"
1318
      ]
1319
    },
1320
    {
1321
      "cell_type": "code",
1322
      "execution_count": null,
1323
      "metadata": {
1324
        "id": "xVE4XVbBj5fz"
1325
      },
1326
      "outputs": [],
1327
      "source": [
1328
        "normals_gt[0].reshape(H, W, 3)[18, 64, :]"
1329
      ]
1330
    },
1331
    {
1332
      "cell_type": "code",
1333
      "execution_count": null,
1334
      "metadata": {
1335
        "id": "Kijvl705lOPr"
1336
      },
1337
      "outputs": [],
1338
      "source": [
1339
        "def jon(x, eps=1e-7):\n",
1340
        "  denom_sq = x ** 2\n",
1341
        "  normal = x / jnp.sqrt(jnp.maximum(denom_sq, eps))\n",
1342
        "  return jnp.where(denom_sq \u003c eps, jnp.zeros_like(normal), normal)\n",
1343
        "\n",
1344
        "def dor(x, eps=1e-7):\n",
1345
        "  return x / jnp.sqrt(jnp.maximum(x**2, eps))\n",
1346
        "\n",
1347
        "\n",
1348
        "x = jnp.linspace(-0.001, 0.001, 10000)\n",
1349
        "plt.plot(x, jon(x), x, dor(x))"
1350
      ]
1351
    },
1352
    {
1353
      "cell_type": "code",
1354
      "execution_count": null,
1355
      "metadata": {
1356
        "id": "OMCb95CSPk7R"
1357
      },
1358
      "outputs": [],
1359
      "source": [
1360
        "ind = 0\n",
1361
        "\n",
1362
        "print(rays_d_vec[ind].reshape(H, W, 3)[0, W//2, :])  # Top is x\n",
1363
        "print(rays_d_vec[ind].reshape(H, W, 3)[H//2, 1, :])  # Left is y\n",
1364
        "\n",
1365
        "n = (normals_gt[ind].reshape(H, W, 3) @ R.T) * 0.5 + 0.5\n",
1366
        "plt.imshow(n)"
1367
      ]
1368
    },
1369
    {
1370
      "cell_type": "code",
1371
      "execution_count": null,
1372
      "metadata": {
1373
        "id": "8jAr6Lcp9dl8"
1374
      },
1375
      "outputs": [],
1376
      "source": []
1377
    },
1378
    {
1379
      "cell_type": "code",
1380
      "execution_count": null,
1381
      "metadata": {
1382
        "id": "KNyuQ0oryIWO"
1383
      },
1384
      "outputs": [],
1385
      "source": [
1386
        "plt.imshow(envmap_gt)\n",
1387
        "plt.figure()\n",
1388
        "sdf = params_to_sdf(sdf_params, jnp.array([0]))\n",
1389
        "mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1390
        "\n",
1391
        "plt.imshow(mask[0])"
1392
      ]
1393
    },
1394
    {
1395
      "cell_type": "code",
1396
      "execution_count": null,
1397
      "metadata": {
1398
        "id": "tvbSIOQk9OX7"
1399
      },
1400
      "outputs": [],
1401
      "source": []
1402
    },
1403
    {
1404
      "cell_type": "code",
1405
      "execution_count": null,
1406
      "metadata": {
1407
        "id": "6Tx54EIFxFPZ"
1408
      },
1409
      "outputs": [],
1410
      "source": [
1411
        "plt.imshow(masks_gt[-1])"
1412
      ]
1413
    },
1414
    {
1415
      "cell_type": "code",
1416
      "execution_count": null,
1417
      "metadata": {
1418
        "id": "gwCzHL7CykR3"
1419
      },
1420
      "outputs": [],
1421
      "source": [
1422
        "#plt.figure(figsize=[12, 12])\n",
1423
        "#plt.plot(masks_gt[-1][34, :], '.')\n",
1424
        "plt.plot((1-masks_gt[-1]).sum(1), '.')"
1425
      ]
1426
    },
1427
    {
1428
      "cell_type": "code",
1429
      "execution_count": null,
1430
      "metadata": {
1431
        "id": "hyD_ca65q9aK"
1432
      },
1433
      "outputs": [],
1434
      "source": [
1435
        "# Plot rendered image and GT image\n",
1436
        "sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1437
        "mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1438
        "if 'envmap' in gt_list:\n",
1439
        "  envmap = exposure * envmap_gt\n",
1440
        "if 'mask' in gt_list:\n",
1441
        "  mask = masks_gt[ind]\n",
1442
        "if 'materials' in gt_list:\n",
1443
        "  materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1444
        "res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1445
        "plt.figure()\n",
1446
        "plt.subplot(121)\n",
1447
        "plt.imshow(res)\n",
1448
        "plt.axis('off')\n",
1449
        "plt.subplot(122)\n",
1450
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1451
        "plt.axis('off')\n",
1452
        "plt.show()\n",
1453
        "\n",
1454
        "diff = imgs_gt[ind].reshape(H, W, 3) - res\n",
1455
        "#loss_per_element = (diff ** 2).sum(-1)\n",
1456
        "#loss_per_element = jnp.abs(diff).sum(-1)\n",
1457
        "if False:\n",
1458
        "  p = 2 - 1.5 * i / num_iters\n",
1459
        "  data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)\n",
1460
        "  print(\"Using graduated nonconvexity in the loss\")\n",
1461
        "else:\n",
1462
        "  data_loss = (jnp.abs(diff) ** 2).sum()\n",
1463
        "  print(\"Using L2 loss\")\n",
1464
        "\n",
1465
        "data_loss = data_loss / 1 / 128 / 128\n",
1466
        "print(data_loss)"
1467
      ]
1468
    },
1469
    {
1470
      "cell_type": "code",
1471
      "execution_count": null,
1472
      "metadata": {
1473
        "id": "iZii_KyJrN17"
1474
      },
1475
      "outputs": [],
1476
      "source": [
1477
        "diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1478
        "diff.min(), diff.max()"
1479
      ]
1480
    },
1481
    {
1482
      "cell_type": "code",
1483
      "execution_count": null,
1484
      "metadata": {
1485
        "id": "5G4_YPmorCnm"
1486
      },
1487
      "outputs": [],
1488
      "source": [
1489
        "plt.imshow(mask[0])"
1490
      ]
1491
    },
1492
    {
1493
      "cell_type": "code",
1494
      "execution_count": null,
1495
      "metadata": {
1496
        "id": "URJTcUOo9rdb"
1497
      },
1498
      "outputs": [],
1499
      "source": []
1500
    },
1501
    {
1502
      "cell_type": "code",
1503
      "execution_count": null,
1504
      "metadata": {
1505
        "id": "bhkb4AWsGzmn"
1506
      },
1507
      "outputs": [],
1508
      "source": [
1509
        " media.show_video(envmaps, height=300)"
1510
      ]
1511
    },
1512
    {
1513
      "cell_type": "code",
1514
      "execution_count": null,
1515
      "metadata": {
1516
        "id": "jsUCKGrFbyt7"
1517
      },
1518
      "outputs": [],
1519
      "source": [
1520
        "num_rows = 6\n",
1521
        "cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]\n",
1522
        "sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))\n",
1523
        "sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))\n",
1524
        "plt.imshow(jnp.float32(sdfs[2] \u003e 0) - jnp.float32(sdfs[1] \u003e 0))"
1525
      ]
1526
    },
1527
    {
1528
      "cell_type": "code",
1529
      "execution_count": null,
1530
      "metadata": {
1531
        "id": "p2lbNspacFqv"
1532
      },
1533
      "outputs": [],
1534
      "source": [
1535
        "plt.imshow(sdfs[3] \u003e 0)"
1536
      ]
1537
    },
1538
    {
1539
      "cell_type": "code",
1540
      "execution_count": null,
1541
      "metadata": {
1542
        "id": "_cNuUtm2Tz3L"
1543
      },
1544
      "outputs": [],
1545
      "source": [
1546
        "jnp.where(alpha_gt == 1.0, imgs_gt, 0.0).max()"
1547
      ]
1548
    },
1549
    {
1550
      "cell_type": "code",
1551
      "execution_count": null,
1552
      "metadata": {
1553
        "id": "YYEJmEDP4KKk"
1554
      },
1555
      "outputs": [],
1556
      "source": [
1557
        "plt.imshow(np.float32(imgs_gt[5].reshape(H, W, 3) == 1))"
1558
      ]
1559
    },
1560
    {
1561
      "cell_type": "code",
1562
      "execution_count": null,
1563
      "metadata": {
1564
        "id": "WKPz2rHYKZHV"
1565
      },
1566
      "outputs": [],
1567
      "source": [
1568
        "media.show_video(envmaps, height=200)"
1569
      ]
1570
    },
1571
    {
1572
      "cell_type": "code",
1573
      "execution_count": null,
1574
      "metadata": {
1575
        "id": "gc21N4x__zqC"
1576
      },
1577
      "outputs": [],
1578
      "source": [
1579
        "ind = 11\n",
1580
        "sdf = params_to_sdf(sdf_params, jnp.array([ind]))\n",
1581
        "mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)\n",
1582
        "envmap = envmap_gt\n",
1583
        "res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1584
        "plt.figure()\n",
1585
        "plt.subplot(121)\n",
1586
        "plt.imshow(res)\n",
1587
        "plt.axis('off')\n",
1588
        "plt.subplot(122)\n",
1589
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1590
        "plt.axis('off')\n",
1591
        "plt.show()\n"
1592
      ]
1593
    },
1594
    {
1595
      "cell_type": "code",
1596
      "execution_count": null,
1597
      "metadata": {
1598
        "id": "XLeHIuktbhJN"
1599
      },
1600
      "outputs": [],
1601
      "source": [
1602
        "omega_xyz.shape"
1603
      ]
1604
    },
1605
    {
1606
      "cell_type": "code",
1607
      "execution_count": null,
1608
      "metadata": {
1609
        "id": "BESEhvA8iJ6h"
1610
      },
1611
      "outputs": [],
1612
      "source": [
1613
        "ind\n",
1614
        "\n",
1615
        "materials = jax.tree_map(lambda x: x[ind], materials_gt)\n",
1616
        "envmap = envmap_gt\n",
1617
        "\n",
1618
        "mask = jnp.sum(-omega_xyz * rays_d_vec[ind, H//2*W+W//2, :][None, :], axis=-1) \u003c 0.76649692  # Occluder is aligned with the camera\n",
1619
        "\n",
1620
        "res = linear_to_srgb(render_partial(envmap, mask.reshape(envmap_H, envmap_W), materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))\n",
1621
        "#diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1622
        "plt.figure()\n",
1623
        "plt.imshow(res)\n",
1624
        "\"\"\"\n",
1625
        "plt.figure()\n",
1626
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1627
        "plt.figure()\n",
1628
        "err = np.abs(diff).sum(-1)\n",
1629
        "plt.imshow(err, cmap='gray')\n",
1630
        "plt.colorbar()\n",
1631
        "\"\"\";"
1632
      ]
1633
    },
1634
    {
1635
      "cell_type": "code",
1636
      "execution_count": null,
1637
      "metadata": {
1638
        "id": "l2d7rmUmcRGH"
1639
      },
1640
      "outputs": [],
1641
      "source": [
1642
        "plt.imshow(normals_gt[ind].reshape(H, W, 3))"
1643
      ]
1644
    },
1645
    {
1646
      "cell_type": "code",
1647
      "execution_count": null,
1648
      "metadata": {
1649
        "id": "PTMdZXyboxE-"
1650
      },
1651
      "outputs": [],
1652
      "source": [
1653
        "envmap = envmap_gt\n",
1654
        "\n",
1655
        "rows = []\n",
1656
        "for i in range(H):\n",
1657
        "  inds = np.arange(W) + i * W\n",
1658
        "  materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1659
        "  res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))\n",
1660
        "  rows.append(res)\n",
1661
        "\n",
1662
        "res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1663
        "diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1664
        "plt.figure()\n",
1665
        "plt.imshow(res)\n",
1666
        "plt.figure()\n",
1667
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1668
        "plt.figure()\n",
1669
        "err = np.abs(diff).sum(-1)\n",
1670
        "plt.imshow(err, cmap='gray')\n",
1671
        "plt.colorbar()\n"
1672
      ]
1673
    },
1674
    {
1675
      "cell_type": "code",
1676
      "execution_count": null,
1677
      "metadata": {
1678
        "id": "98uykDQhWc1b"
1679
      },
1680
      "outputs": [],
1681
      "source": [
1682
        "envmap = envmap_linear\n",
1683
        "\n",
1684
        "rows = []\n",
1685
        "for i in range(H):\n",
1686
        "  cols = []\n",
1687
        "  for j in range(W):\n",
1688
        "    inds = np.arange(1) + i * W + j\n",
1689
        "    materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1690
        "    res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))\n",
1691
        "    cols.append(res)\n",
1692
        "  row = jnp.concatenate(cols, axis=0)\n",
1693
        "  rows.append(row)\n",
1694
        "\n",
1695
        "res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1696
        "diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1697
        "plt.figure()\n",
1698
        "plt.imshow(res)\n",
1699
        "plt.figure()\n",
1700
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1701
        "plt.figure()\n",
1702
        "err = np.abs(diff).sum(-1)\n",
1703
        "plt.imshow(err, cmap='gray')\n",
1704
        "plt.colorbar()\n"
1705
      ]
1706
    },
1707
    {
1708
      "cell_type": "code",
1709
      "execution_count": null,
1710
      "metadata": {
1711
        "id": "b6Mo5B8ba3Il"
1712
      },
1713
      "outputs": [],
1714
      "source": [
1715
        "res = jnp.concatenate(rows, axis=0).reshape(-1, W, 3)\n",
1716
        "#diff = res - imgs_gt[ind].reshape(H, W, 3)\n",
1717
        "plt.figure()\n",
1718
        "plt.imshow(res)\n",
1719
        "plt.figure()\n",
1720
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1721
        "plt.figure()\n",
1722
        "err = np.abs(diff).sum(-1)\n",
1723
        "plt.imshow(err, cmap='gray')\n",
1724
        "plt.colorbar()\n"
1725
      ]
1726
    },
1727
    {
1728
      "cell_type": "code",
1729
      "execution_count": null,
1730
      "metadata": {
1731
        "id": "vKdgW6WSjFmF"
1732
      },
1733
      "outputs": [],
1734
      "source": [
1735
        "materials['albedo'].shape"
1736
      ]
1737
    },
1738
    {
1739
      "cell_type": "code",
1740
      "execution_count": null,
1741
      "metadata": {
1742
        "id": "lf7lQmu2i-o-"
1743
      },
1744
      "outputs": [],
1745
      "source": [
1746
        "plt.plot(err[:, 64])"
1747
      ]
1748
    },
1749
    {
1750
      "cell_type": "code",
1751
      "execution_count": null,
1752
      "metadata": {
1753
        "id": "ixZYlvq8jPT1"
1754
      },
1755
      "outputs": [],
1756
      "source": [
1757
        "def interp2d(grids, inds):\n",
1758
        "  \"\"\"\n",
1759
        "  grids is [H, W, d]\n",
1760
        "  inds is [2, ...], with the 0th dim being elevation and 1st azimuth\n",
1761
        "  \"\"\"\n",
1762
        "\n",
1763
        "  results = []\n",
1764
        "  for grid in [grids[:, :, d] for d in range(grids.shape[-1])]:\n",
1765
        "    res = jax.scipy.ndimage.map_coordinates(grid, inds, order=1, mode='wrap')\n",
1766
        "    results.append(res)\n",
1767
        "\n",
1768
        "  return jnp.stack(results, axis=-1)\n",
1769
        "\n",
1770
        "\n",
1771
        "def render_pixel_mirror(refdir, envmap, mask):\n",
1772
        "  x, y, z = refdir\n",
1773
        "  theta = jnp.arctan2(jnp.sqrt(x ** 2 + y ** 2), z)\n",
1774
        "  phi   = jnp.arctan2(y, x)\n",
1775
        "  # Quantize to get index\n",
1776
        "  theta_ind = jnp.floor(envmap_H * theta / jnp.pi).astype(jnp.int32)\n",
1777
        "  phi_ind = jnp.round(envmap_W * phi / 2.0 / jnp.pi).astype(jnp.int32)\n",
1778
        "  return (envmap * mask[:, :, None])[theta_ind, phi_ind]\n",
1779
        "  #return interp2d(envmap * mask[:, :, None], [theta*(envmap_H-1)/jnp.pi, phi*(envmap_W-1)/2/jnp.pi])\n",
1780
        "\n",
1781
        "def render_mirror(envmap, mask, normals, rays_d, alpha, oxyz, rad, shape='sphere'):\n",
1782
        "  \"\"\"\n",
1783
        "  envmap:     shape [h, w, 3]\n",
1784
        "  mask:       shape [h, w]\n",
1785
        "  materials:  dictionary with entries of shape [N, 3]\n",
1786
        "  normals:    shape [N, 3]\n",
1787
        "  rays_d:     shape [N, 3]\n",
1788
        "  alpha:      shape [N, 1]\n",
1789
        "  oxyz:       shape [1, 3]  (shape center)\n",
1790
        "  rad:        float         (shape radius)\n",
1791
        "  \n",
1792
        "  output: rendered colors, shape [N, 3]\n",
1793
        "  \"\"\"\n",
1794
        "  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
1795
        "  rays_d_norm = rays_d / jnp.sqrt(d_norm_sq + 1e-10)  # [N, 3]\n",
1796
        "\n",
1797
        "  refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm   # [N, 3]\n",
1798
        "  print(refdirs.shape, envmap.shape, mask.shape)\n",
1799
        "  colors = jax.vmap(render_pixel_mirror, in_axes=(0, None, None))(refdirs, envmap, mask)     \n",
1800
        "  \n",
1801
        "  return colors * alpha\n",
1802
        "\n",
1803
        "img_ind = 0\n",
1804
        "d = -rays_d_vec[img_ind] #* jnp.array([1.0, -1.0, 1.0])\n",
1805
        "R = jnp.array([[ 0.0, 1.0, 0.0],\n",
1806
        "               [-1.0, 0.0, 0.0],\n",
1807
        "               [ 0.0, 0.0, 1.0]])\n",
1808
        "n = normals_gt[img_ind] @ R\n",
1809
        "\n",
1810
        "res = render_mirror(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0,\n",
1811
        "                    n, d, alpha_gt[img_ind], jnp.zeros((3,)), 1.0)\n",
1812
        "res_srgb = linear_to_srgb(res.reshape(H, W, 3))\n",
1813
        "plt.figure(figsize=[12, 12])\n",
1814
        "plt.imshow(res_srgb, interpolation='nearest')\n",
1815
        "plt.axis('off')"
1816
      ]
1817
    },
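    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Quick sanity check of the interp2d helper above (illustrative only; the toy\n",
        "# grid and query points are made up). Index convention: row 0 of `inds` is the\n",
        "# elevation coordinate, row 1 the azimuth coordinate.\n",
        "toy_grid = jnp.arange(8.0).reshape(2, 4, 1)\n",
        "print(interp2d(toy_grid, jnp.array([[0.5], [0.0]])))  # halfway between rows -\u003e 2.0\n",
        "print(interp2d(toy_grid, jnp.array([[0.0], [2.5]])))  # halfway between cols 2 and 3 -\u003e 2.5\n"
      ]
    },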
1818
    {
1819
      "cell_type": "code",
1820
      "execution_count": null,
1821
      "metadata": {
1822
        "id": "6LjFQh_tcFiA"
1823
      },
1824
      "outputs": [],
1825
      "source": [
1826
        "plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)"
1827
      ]
1828
    },
1829
    {
1830
      "cell_type": "code",
1831
      "execution_count": null,
1832
      "metadata": {
1833
        "id": "9_bzZTokbi0r"
1834
      },
1835
      "outputs": [],
1836
      "source": [
1837
        "rays_d_vec[0].reshape(H, W, 3)[0, W//2, :]"
1838
      ]
1839
    },
1840
    {
1841
      "cell_type": "code",
1842
      "execution_count": null,
1843
      "metadata": {
1844
        "id": "RDksiHZkh8lF"
1845
      },
1846
      "outputs": [],
1847
      "source": [
1848
        "plt.imshow(envmap_gt)"
1849
      ]
1850
    },
1851
    {
1852
      "cell_type": "code",
1853
      "execution_count": null,
1854
      "metadata": {
1855
        "id": "lXdTJL1Ozw1Y"
1856
      },
1857
      "outputs": [],
1858
      "source": [
1859
        "with open('{DIRECTORY}/r_0_lamb.png', 'rb') as f:\n",
1860
        "  img_lamb = np.array(Image.open(f))[:, :, :3] / 255.0\n"
1861
      ]
1862
    },
1863
    {
1864
      "cell_type": "code",
1865
      "execution_count": null,
1866
      "metadata": {
1867
        "id": "raQb_af6vO1M"
1868
      },
1869
      "outputs": [],
1870
      "source": [
1871
        "plt.imshow(alpha_gt[ind].reshape(H, W))"
1872
      ]
1873
    },
1874
    {
1875
      "cell_type": "code",
1876
      "execution_count": null,
1877
      "metadata": {
1878
        "id": "7F65WtLmReQL"
1879
      },
1880
      "outputs": [],
1881
      "source": [
1882
        "ind = 20\n",
1883
        "\n",
1884
        "with open(f'{DIRECTORY}/r_{ind}_larger.png', 'rb') as f:\n",
1885
        "  res_gt = np.float32(Image.open(f)) / 255.0\n",
1886
        "  alpha = res_gt[:, :, 3:]\n",
1887
        "  res_gt = res_gt[:, :, :3] * alpha\n",
1888
        "\n",
1889
        "#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)\n",
1890
        "#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5 # + 4.0\n",
1891
        "#inds = jnp.stack([jnp.mod(elevation, envmap_H), jnp.mod(azimuth, envmap_W)], axis=0)\n",
1892
        "#envmap = interp2d(envmap_gt, inds)\n",
1893
        "#print(envmap.shape)\n",
1894
        "envmap = envmap_gt\n",
1895
        "\n",
1896
        "rows = []\n",
1897
        "for i in range(H):\n",
1898
        "  inds = np.arange(W) + i * W\n",
1899
        "  materials = jax.tree_map(lambda x: x[ind, inds], materials_gt)\n",
1900
        "  res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, {'albedo': jnp.ones((W, 3))*0.15}, normals_gt[ind, inds], rays_d_r[ind, inds], alpha.reshape(-1, 1)[inds]))\n",
1901
        "  rows.append(res)\n",
1902
        "\n",
1903
        "res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)\n",
1904
        "plt.imshow(res)\n",
1905
        "plt.figure()\n",
1906
        "\n",
1907
        "plt.imshow(res_gt)\n",
1908
        "\n",
1909
        "plt.figure()\n",
1910
        "diff = res - res_gt\n",
1911
        "plt.imshow(jnp.abs(diff).sum(-1)/3, cmap='gray')\n",
1912
        "plt.colorbar()\n",
1913
        "\n",
1914
        "print(jnp.abs(diff[30:50, 80:90]).max())"
1915
      ]
1916
    },
1917
    {
1918
      "cell_type": "code",
1919
      "execution_count": null,
1920
      "metadata": {
1921
        "id": "udan6pb0Tk7b"
1922
      },
1923
      "outputs": [],
1924
      "source": []
1925
    },
1926
    {
1927
      "cell_type": "code",
1928
      "execution_count": null,
1929
      "metadata": {
1930
        "id": "CGoSY7aIRmpz"
1931
      },
1932
      "outputs": [],
1933
      "source": []
1934
    },
1935
    {
1936
      "cell_type": "code",
1937
      "execution_count": null,
1938
      "metadata": {
1939
        "collapsed": true,
1940
        "id": "vr0zZ_LavnOQ"
1941
      },
1942
      "outputs": [],
1943
      "source": [
1944
        "ind = 0\n",
1945
        "res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1946
        "#res = render(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R.T, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)\n",
1947
        "# ???????????\n",
1948
        "plt.imshow(linear_to_srgb(res))\n",
1949
        "plt.figure()\n",
1950
        "plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1951
        "plt.figure()"
1952
      ]
1953
    },
1954
    {
1955
      "cell_type": "code",
1956
      "execution_count": null,
1957
      "metadata": {
1958
        "id": "XXPbbl_Bq07h"
1959
      },
1960
      "outputs": [],
1961
      "source": [
1962
        "ind = 20\n",
1963
        "#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:\n",
1964
        "with open(f'{DIRECTORY}/r_{ind}_envmap_nn.png', 'rb') as f:\n",
1965
        "  res_gt = np.float32(Image.open(f)) / 255.0\n",
1966
        "  alpha = res_gt[:, :, 3:]\n",
1967
        "  res_gt = res_gt[:, :, :3] * alpha\n",
1968
        "res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
1969
        "#print(res[60:70, 60:70, :])\n",
1970
        "res = linear_to_srgb(res)\n",
1971
        "plt.subplot(121)\n",
1972
        "plt.imshow(res)\n",
1973
        "plt.axis('off')\n",
1974
        "plt.subplot(122)\n",
1975
        "#plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
1976
        "plt.imshow(res_gt)\n",
1977
        "plt.axis('off')\n"
1978
      ]
1979
    },
1980
    {
1981
      "cell_type": "code",
1982
      "execution_count": null,
1983
      "metadata": {
1984
        "id": "Hrd8TAj9Hxc0"
1985
      },
1986
      "outputs": [],
1987
      "source": [
1988
        "#plt.imshow(3*(res - res_gt))\n",
1989
        "plt.imshow(0.5 * srgb_to_linear(res_gt) / srgb_to_linear(res))\n",
1990
        "plt.axis('off')"
1991
      ]
1992
    },
1993
    {
1994
      "cell_type": "code",
1995
      "execution_count": null,
1996
      "metadata": {
1997
        "id": "O_QxjQomHiSb"
1998
      },
1999
      "outputs": [],
2000
      "source": [
2001
        "plt.scatter(res[:, :, 0], res[:, :, 1])\n",
2002
        "plt.figure()\n",
2003
        "plt.scatter(res_gt[:, :, 0], res_gt[:, :, 1])"
2004
      ]
2005
    },
2006
    {
2007
      "cell_type": "code",
2008
      "execution_count": null,
2009
      "metadata": {
2010
        "id": "7XzYX89pn78B"
2011
      },
2012
      "outputs": [],
2013
      "source": [
2014
        "plt.imshow(res - res_gt, cmap='gray')\n"
2015
      ]
2016
    },
2017
    {
2018
      "cell_type": "code",
2019
      "execution_count": null,
2020
      "metadata": {
2021
        "id": "jXFCPk7dY7Eg"
2022
      },
2023
      "outputs": [],
2024
      "source": [
2025
        "plt.imshow((res - res_gt).sum(-1) / 3.0, cmap='gray')\n",
2026
        "plt.colorbar()"
2027
      ]
2028
    },
2029
    {
2030
      "cell_type": "code",
2031
      "execution_count": null,
2032
      "metadata": {
2033
        "id": "XbawH1gkzfyl"
2034
      },
2035
      "outputs": [],
2036
      "source": [
2037
        "ind = 70\n",
2038
        "#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:\n",
2039
        "with open(f'{DIRECTORY}/r_{ind}.png', 'rb') as f:\n",
2040
        "  res_gt = np.float32(Image.open(f)) / 255.0\n",
2041
        "  alpha = res_gt[:, :, 3:]\n",
2042
        "  res_gt = res_gt[:, :, :3] * alpha\n",
2043
        "res = render(envmap_gt*0.0 + jnp.where(jnp.cos(omega_phi.reshape(envmap_H, envmap_W, 1) + 0.0) \u003c 0, 1.0, 0.0),\n",
2044
        "             envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*1.0}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2045
        "#print(res[60:70, 60:70, :])\n",
2046
        "res = linear_to_srgb(res)\n",
2047
        "plt.subplot(121)\n",
2048
        "plt.imshow(res, vmin=0.0, vmax=1.0)\n",
2049
        "plt.axis('off')\n",
2050
        "plt.subplot(122)\n",
2051
        "#plt.imshow(imgs_gt[ind].reshape(H, W, 3))\n",
2052
        "plt.imshow(res_gt, vmin=0.0, vmax=1.0)\n",
2053
        "plt.axis('off')\n",
2054
        "plt.figure()\n",
2055
        "diff = res - res_gt\n",
2056
        "plt.imshow(-diff[:, :, 0], cmap='gray')\n",
2057
        "plt.colorbar()"
2058
      ]
2059
    },
2060
    {
2061
      "cell_type": "code",
2062
      "execution_count": null,
2063
      "metadata": {
2064
        "id": "4pvhvCCPOqij"
2065
      },
2066
      "outputs": [],
2067
      "source": [
2068
        "res[70:90, 70:90, :].min(), res_gt[70:90, 70:90, :].min()"
2069
      ]
2070
    },
2071
    {
2072
      "cell_type": "code",
2073
      "execution_count": null,
2074
      "metadata": {
2075
        "id": "3nmFDWBzOzaj"
2076
      },
2077
      "outputs": [],
2078
      "source": [
2079
        "diff[60, 60, :]"
2080
      ]
2081
    },
2082
    {
2083
      "cell_type": "code",
2084
      "execution_count": null,
2085
      "metadata": {
2086
        "id": "rAaevaJN1b6E"
2087
      },
2088
      "outputs": [],
2089
      "source": []
2090
    },
2091
    {
2092
      "cell_type": "code",
2093
      "execution_count": null,
2094
      "metadata": {
2095
        "id": "M1qNaYMqeGHs"
2096
      },
2097
      "outputs": [],
2098
      "source": [
2099
        "a = alpha_gt[img_ind].reshape(H, W, 1)\n",
2100
        "plt.figure(figsize=[12, 12])\n",
2101
        "plt.subplot(121)\n",
2102
        "plt.imshow(res_srgb * a + (1.0 - a))\n",
2103
        "plt.axis('off')\n",
2104
        "plt.subplot(122)\n",
2105
        "plt.imshow(img)\n",
2106
        "plt.axis('off')"
2107
      ]
2108
    },
2109
    {
2110
      "cell_type": "code",
2111
      "execution_count": null,
2112
      "metadata": {
2113
        "id": "mXzlZ7xyeazs"
2114
      },
2115
      "outputs": [],
2116
      "source": [
2117
        "media.show_video([img_ggx, res_srgb], fps=2, height=200)"
2118
      ]
2119
    },
2120
    {
2121
      "cell_type": "code",
2122
      "execution_count": null,
2123
      "metadata": {
2124
        "id": "A3A_bf80cL79"
2125
      },
2126
      "outputs": [],
2127
      "source": [
2128
        "#with open('{DIRECTORY}/r_0.png', 'rb') as f:\n",
2129
        "#  img = np.array(Image.open(f))[:, :, :3] / 255.0\n",
2130
        "\n",
2131
        "with open('{DIRECTORY}/r_0_true_mirror.png', 'rb') as f:\n",
2132
        "  img_tm = np.array(Image.open(f))[:, :, :3] / 255.0\n",
2133
        "\n",
2134
        "with open('{DIRECTORY}/r_0_ggx_mirror.png', 'rb') as f:\n",
2135
        "  img_ggx = np.array(Image.open(f))[:, :, :3] / 255.0\n"
2136
      ]
2137
    },
2138
    {
2139
      "cell_type": "code",
2140
      "execution_count": null,
2141
      "metadata": {
2142
        "id": "vZuJPrZKdj-U"
2143
      },
2144
      "outputs": [],
2145
      "source": [
2146
        "plt.imshow(jnp.log10(jnp.abs(img - res_srgb).sum(-1)/3), cmap='gray')\n",
2147
        "plt.colorbar()"
2148
      ]
2149
    },
2150
    {
2151
      "cell_type": "code",
2152
      "execution_count": null,
2153
      "metadata": {
2154
        "id": "jtz1eDbGsQtR"
2155
      },
2156
      "outputs": [],
2157
      "source": []
2158
    },
2159
    {
2160
      "cell_type": "code",
2161
      "execution_count": null,
2162
      "metadata": {
2163
        "id": "TNpwJkd0sgvc"
2164
      },
2165
      "outputs": [],
2166
      "source": [
2167
        "# Strength: 0.5\n",
2168
        "img_lamb[64, 64]"
2169
      ]
2170
    },
2171
    {
2172
      "cell_type": "code",
2173
      "execution_count": null,
2174
      "metadata": {
2175
        "id": "GjmlQ241s7ij"
2176
      },
2177
      "outputs": [],
2178
      "source": [
2179
        "# Strength: 0.25\n",
2180
        "img_lamb[64, 64]"
2181
      ]
2182
    },
2183
    {
2184
      "cell_type": "code",
2185
      "execution_count": null,
2186
      "metadata": {
2187
        "id": "BhsBMDZzshXb"
2188
      },
2189
      "outputs": [],
2190
      "source": [
2191
        "# Strength: 0.2\n",
2192
        "img_lamb[64, 64]"
2193
      ]
2194
    },
2195
    {
2196
      "cell_type": "code",
2197
      "execution_count": null,
2198
      "metadata": {
2199
        "id": "5VexRaTAs2_T"
2200
      },
2201
      "outputs": [],
2202
      "source": [
2203
        "def linear_to_srgb(linear):\n",
2204
        "  srgb0 = 323 / 25 * linear\n",
2205
        "  srgb1 = (211 * linear**(5 / 12) - 11) / 200\n",
2206
        "  return np.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
2207
        "\n",
2208
        "\n",
2209
        "def srgb_to_linear(srgb):\n",
2210
        "  linear0 = srgb * 25 / 323\n",
2211
        "  linear1 = ((200 * srgb + 11) / 211) ** (12 / 5)\n",
2212
        "  return np.where(srgb \u003c= 0.0031308 * 25 / 323, linear0, linear1)\n",
2213
        "\n",
2214
        "srgb_to_linear(0.66666667), srgb_to_linear(0.48235294), srgb_to_linear(0.43529412)\n"
2215
      ]
2216
    },
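    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Round-trip sanity check for the two conversions above (illustrative; assumes\n",
        "# the srgb_to_linear cutoff 0.0031308 * 323 / 25 = 0.04045, matching the\n",
        "# 0.0031308 cutoff in linear_to_srgb).\n",
        "vals = np.linspace(0.0, 1.0, 101)\n",
        "print(np.abs(srgb_to_linear(linear_to_srgb(vals)) - vals).max())  # ~0\n"
      ]
    },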
2217
    {
2218
      "cell_type": "code",
2219
      "execution_count": null,
2220
      "metadata": {
2221
        "id": "JEJzvVIltsvM"
2222
      },
2223
      "outputs": [],
2224
      "source": [
2225
        "def linear_to_srgb(linear):\n",
2226
        "  srgb0 = 323 / 25 * linear\n",
2227
        "  srgb1 = (211 * linear**(5 / 12) - 11) / 200\n",
2228
        "  return np.where(linear \u003c= 0.0031308, srgb0, srgb1)\n",
2229
        "\n",
2230
        "print(linear_to_srgb(0.5))"
2231
      ]
2232
    },
2233
    {
2234
      "cell_type": "code",
2235
      "execution_count": null,
2236
      "metadata": {
2237
        "id": "ku7QqJbyFPBe"
2238
      },
2239
      "outputs": [],
2240
      "source": [
2241
        "ind = 0\n",
2242
        "with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:\n",
2243
        "  #res_gt = np.float32(Image.open(f)) / 255.0\n",
2244
        "  res_gt = imageio.imread(f, 'exr')\n",
2245
        "\n",
2246
        "  alpha = res_gt[:, :, 3:]\n",
2247
        "  res_gt = res_gt[:, :, :3] * alpha\n",
2248
        "\n",
2249
        "plt.imshow(res_gt, cmap='gray', interpolation='nearest')\n",
2250
        "\n",
2251
        "\n",
2252
        "res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2253
        "res = res * alpha #res.repeat(3, 2) * alpha\n",
2254
        "plt.figure()\n",
2255
        "plt.imshow(res, cmap='gray', interpolation='nearest')\n",
2256
        "\n",
2257
        "plt.figure()\n",
2258
        "plt.imshow(jnp.abs(res - res_gt).sum(-1))\n",
2259
        "#plt.imshow(res[..., 1] - res_gt[..., 1])\n",
2260
        "plt.colorbar()"
2261
      ]
2262
    },
2263
    {
2264
      "cell_type": "code",
2265
      "execution_count": null,
2266
      "metadata": {
2267
        "id": "2js5LOhdUk01"
2268
      },
2269
      "outputs": [],
2270
      "source": [
2271
        "ind = 0\n",
2272
        "with open(f'{DIRECTORY}/r_{ind}_bg.exr', 'rb') as f:\n",
2273
        "  #res_gt = np.float32(Image.open(f)) / 255.0\n",
2274
        "  res_gt = imageio.imread(f, 'exr')\n",
2275
        "\n",
2276
        "  alpha = res_gt[:, :, 3:]\n",
2277
        "  res_gt = res_gt[:, :, :3] * alpha\n",
2278
        "\n",
2279
        "res_gt_srgb = linear_to_srgb(res_gt)\n",
2280
        "plt.imshow(res_gt_srgb, interpolation='nearest')\n",
2281
        "\n",
2282
        "\n",
2283
        "#omega_phi, omega_theta\n",
2284
        "dirs = rays_d_vec[ind].reshape(H, W, 3)\n",
2285
        "#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)\n",
2286
        "#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5\n",
2287
        "#inds = jnp.stack([elevation, azimuth], axis=0)\n",
2288
        "#env = interp2d(envmap_gt, inds)\n",
2289
        "dirs_azimuth = np.arctan2(dirs[..., 1], dirs[..., 0])\n",
2290
        "nr = np.sqrt(dirs[..., 1] ** 2 + dirs[..., 0] ** 2)\n",
2291
        "dirs_elevation = np.arctan2(nr, dirs[..., 2])\n",
2292
        "import scipy\n",
2293
        "plt.figure()\n",
2294
        "channels = []\n",
2295
        "for i in range(3):\n",
2296
        "  interp = scipy.interpolate.interp2d(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0], envmap_gt[:, :, i])\n",
2297
        "  #ch = interp(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0])\n",
2298
        "  ch = interp(dirs_azimuth[0, :], dirs_elevation[:, 0])\n",
2299
        "  channels.append(ch)\n",
2300
        "#plt.imshow(interp(omega_phi, omega_theta).reshape(envmap_H, envmap_W))\n",
2301
        "plt.imshow(linear_to_srgb(jnp.stack(channels, axis=-1)))\n",
2302
        "#res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)\n",
2303
        "#res = res * alpha #res.repeat(3, 2) * alpha\n",
2304
        "#plt.figure()\n",
2305
        "#plt.imshow(res, cmap='gray', interpolation='nearest')\n",
2306
        "\n",
2307
        "#plt.figure()\n",
2308
        "#plt.imshow(jnp.abs(res - res_gt).sum(-1))\n",
2309
        "#plt.imshow(res[..., 1] - res_gt[..., 1])\n",
2310
        "#plt.colorbar()"
2311
      ]
2312
    },
2313
    {
2314
      "cell_type": "code",
2315
      "execution_count": null,
2316
      "metadata": {
2317
        "id": "h8DjVEBMuDwJ"
2318
      },
2319
      "outputs": [],
2320
      "source": [
2321
        "# When shifting by half a pixel this is the error"
2322
      ]
2323
    },
2324
    {
2325
      "cell_type": "code",
2326
      "execution_count": null,
2327
      "metadata": {
2328
        "id": "BhPEi4MioGj3"
2329
      },
2330
      "outputs": [],
2331
      "source": [
2332
        "plt.imshow(envmap_gt)"
2333
      ]
2334
    },
2335
    {
2336
      "cell_type": "code",
2337
      "execution_count": null,
2338
      "metadata": {
2339
        "id": "tndUYnK7kVjr"
2340
      },
2341
      "outputs": [],
2342
      "source": [
2343
        "res.min(), res.mean(), res.max()"
2344
      ]
2345
    },
2346
    {
2347
      "cell_type": "code",
2348
      "execution_count": null,
2349
      "metadata": {
2350
        "id": "FkrC_npUkgLk"
2351
      },
2352
      "outputs": [],
2353
      "source": [
2354
        "res_gt.min(), res_gt.mean(), res_gt.max()"
2355
      ]
2356
    },
2357
    {
2358
      "cell_type": "code",
2359
      "execution_count": null,
2360
      "metadata": {
2361
        "id": "TYbKAZiUgJLJ"
2362
      },
2363
      "outputs": [],
2364
      "source": [
2365
        "media.show_video([res, res_gt], fps=2, height=256)"
2366
      ]
2367
    },
2368
    {
2369
      "cell_type": "code",
2370
      "execution_count": null,
2371
      "metadata": {
2372
        "id": "jsEnZ94pl0Zq"
2373
      },
2374
      "outputs": [],
2375
      "source": [
2376
        "envmap_gt.min(), envmap_gt.max()"
2377
      ]
2378
    },
2379
    {
2380
      "cell_type": "code",
2381
      "execution_count": null,
2382
      "metadata": {
2383
        "id": "_O3aObadmSAB"
2384
      },
2385
      "outputs": [],
2386
      "source": [
2387
        "envmap_gt.min(), envmap_gt.max()"
2388
      ]
2389
    },
2390
    {
2391
      "cell_type": "code",
2392
      "execution_count": null,
2393
      "metadata": {
2394
        "id": "OTtZ2u2HuMvY"
2395
      },
2396
      "outputs": [],
2397
      "source": [
2398
        "ind = 1\n",
2399
        "with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:\n",
2400
        "  #res_gt = np.float32(Image.open(f)) / 255.0\n",
2401
        "  res_gt = imageio.imread(f, 'exr')\n",
2402
        "\n",
2403
        "  alpha = res_gt[:, :, 3:]\n",
2404
        "  res_gt = res_gt[:, :, :1] * alpha\n",
2405
        "\n",
2406
        "plt.imshow(res_gt, cmap='gray')\n",
2407
        "\n",
2408
        "res = jnp.maximum(0.0, (normals_gt[ind].reshape(H, W, 3) * jnp.array([0.0, 0.0, 1.0])[None, None, :]).sum(-1, keepdims=True)) / jnp.pi\n",
2409
        "res = res * alpha #res.repeat(3, 2) * alpha\n",
2410
        "plt.figure()\n",
2411
        "plt.imshow(res, cmap='gray')\n",
2412
        "\n",
2413
        "plt.figure()\n",
2414
        "plt.imshow(res - res_gt)\n",
2415
        "plt.colorbar()"
2416
      ]
2417
    },
2418
    {
2419
      "cell_type": "code",
2420
      "execution_count": null,
2421
      "metadata": {
2422
        "id": "GoZrZ53vCFqf"
2423
      },
2424
      "outputs": [],
2425
      "source": [
2426
        "plt.imshow(res_gt - res)"
2427
      ]
2428
    },
2429
    {
2430
      "cell_type": "code",
2431
      "execution_count": null,
2432
      "metadata": {
2433
        "id": "8E18U5PULGp_"
2434
      },
2435
      "outputs": [],
2436
      "source": [
2437
        "jnp.linalg.norm(normals_gt[0].reshape(H, W, 3), axis=-1).max()"
2438
      ]
2439
    },
2440
    {
2441
      "cell_type": "code",
2442
      "execution_count": null,
2443
      "metadata": {
2444
        "id": "mxThbxIwLOHT"
2445
      },
2446
      "outputs": [],
2447
      "source": [
2448
        "def foo(rays_d, rays_o, rad=1.0):\n",
2449
        "  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)\n",
2450
        "  o_norm_sq = (rays_o ** 2).sum(-1, keepdims=True)\n",
2451
        "  d_dot_o = (rays_o * rays_d).sum(-1, keepdims=True)\n",
2452
        "  disc = d_norm_sq * (rad ** 2  - o_norm_sq) + d_dot_o ** 2\n",
2453
        "  alpha = jnp.float32(disc \u003e 0)\n",
2454
        "  t_surface = jnp.where(disc \u003e 0, - jnp.sqrt(disc) - d_dot_o, jnp.inf)  # [H, W, 1]\n",
2455
        "\n",
2456
        "  pts = rays_o + rays_d * t_surface\n",
2457
        "\n",
2458
        "  normals = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)\n",
2459
        "\n",
2460
        "  plt.imshow(normals.reshape(H, W, 3) * 0.5 + 0.5)\n",
2461
        "\n",
2462
        "ind = 111\n",
2463
        "foo(rays_d_vec[ind], rays_o_vec[ind])\n",
2464
        "plt.figure()\n",
2465
        "#R = \n",
2466
        "n = normals_gt[ind]\n",
2467
        "plt.imshow(n.reshape(H, W, 3) * 0.5 + 0.5)"
2468
      ]
2469
    },
2470
    {
2471
      "cell_type": "code",
2472
      "execution_count": null,
2473
      "metadata": {
2474
        "id": "BUx2MdOkIl2C"
2475
      },
2476
      "outputs": [],
2477
      "source": [
2478
        "rays_d_vec[-1].reshape(H, W, 3)[64, -1, :]"
2479
      ]
2480
    },
2481
    {
2482
      "cell_type": "code",
2483
      "execution_count": null,
2484
      "metadata": {
2485
        "id": "-HgV69GTssmP"
2486
      },
2487
      "outputs": [],
2488
      "source": [
2489
        "n = normals_gt[0].reshape(H, W, 3) #@ R.T\n",
2490
        "plt.imshow(n * 0.5 + 0.5)"
2491
      ]
2492
    },
2493
    {
2494
      "cell_type": "code",
2495
      "execution_count": null,
2496
      "metadata": {
2497
        "id": "lS6xAAxRxXeQ"
2498
      },
2499
      "outputs": [],
2500
      "source": []
2501
    },
2502
    {
2503
      "cell_type": "code",
2504
      "execution_count": null,
2505
      "metadata": {
2506
        "id": "aAzT6N1HtBHZ"
2507
      },
2508
      "outputs": [],
2509
      "source": [
2510
        "focal = .5 * W / np.tan(0.5 * 0.691111147403717)\n",
2511
        "pixtocams = camera_utils.get_pixtocam(focal, W, H)\n",
2512
        "c2w = jnp.array([[ 0.0, 1.0, 0.0, 0.0],\n",
2513
        "                 [-1.0, 0.0, 0.0, 0.0],\n",
2514
        "                 [ 0.0, 0.0, 1.0, 4.0],\n",
2515
        "                 [ 0.0, 0.0, 0.0, 1.0]])\n",
2516
        "origins, directions, _, _, _ = camera_utils.pixels_to_rays(\n",
2517
        "    jnp.array([W/2.0]),\n",
2518
        "    jnp.array([0]),\n",
2519
        "    pixtocams,\n",
2520
        "    c2w)\n",
2521
        "print(origins)\n",
2522
        "print(directions)  # 'up' is x"
2523
      ]
2524
    },
2525
    {
2526
      "cell_type": "code",
2527
      "execution_count": null,
2528
      "metadata": {
2529
        "id": "xtu32GGARIIj"
2530
      },
2531
      "outputs": [],
2532
      "source": []
2533
    },
2534
    {
2535
      "cell_type": "code",
2536
      "execution_count": null,
2537
      "metadata": {
2538
        "id": "k4fgXxvrRIqf"
2539
      },
2540
      "outputs": [],
2541
      "source": [
2542
        "rays_d_vec[0].reshape(H, W, 3)[0, 64, :]"
2543
      ]
2544
    },
2545
    {
2546
      "cell_type": "code",
2547
      "execution_count": null,
2548
      "metadata": {
2549
        "id": "RtguZTjrLclp"
2550
      },
2551
      "outputs": [],
2552
      "source": [
2553
        "surface_pts = rays_o_vec + t_surface_gt * rays_d_vec\n",
2554
        "print(surface_pts.shape)"
2555
      ]
2556
    },
2557
    {
2558
      "cell_type": "code",
2559
      "execution_count": null,
2560
      "metadata": {
2561
        "id": "aChrXPwo5bFD"
2562
      },
2563
      "outputs": [],
2564
      "source": [
2565
        "finite_surface_points.shape"
2566
      ]
2567
    },
2568
    {
2569
      "cell_type": "code",
2570
      "execution_count": null,
2571
      "metadata": {
2572
        "id": "KfdcZm1vxZUu"
2573
      },
2574
      "outputs": [],
2575
      "source": [
2576
        "@jax.jit\n",
2577
        "def get_dists_sq(x, y):\n",
2578
        "  return ((x - y) ** 2).sum(-1)\n",
2579
        "\n",
2580
        "def subsample_point_cloud(points, num_points):\n",
2581
        "  new_points = [points[0]]\n",
2582
        "  inds = []\n",
2583
        "  for i in range(num_points - 1):\n",
2584
        "    points_to_use_indices = np.random.choice(points.shape[0], size=(1000,), replace=False)\n",
2585
        "    if i % 100 == 0:\n",
2586
        "      print(i)\n",
2587
        "    dists = get_dists_sq(points[points_to_use_indices, None, :], jnp.stack(new_points, axis=0)[None, :, :])\n",
2588
        "    new_point_ind = jnp.argmax(dists.min(axis=1))\n",
2589
        "    new_points.append(points[points_to_use_indices[new_point_ind]])\n",
2590
        "    inds.append(points_to_use_indices[new_point_ind])\n",
2591
        "  return new_points, inds\n",
2592
        "\n",
2593
        "surface_pts = surface_pts.reshape(-1, 3)\n",
2594
        "finite_surface_points = surface_pts[jnp.all(jnp.isfinite(surface_pts), axis=-1)]\n",
2595
        "surface_pts_subsampled = subsample_point_cloud(finite_surface_points, 10000)\n",
2596
        "\n",
2597
        "surface_pts_subsampled = jnp.stack(surface_pts_subsampled, axis=0)"
2598
      ]
2599
    },
2600
    {
2601
      "cell_type": "code",
2602
      "execution_count": null,
2603
      "metadata": {
2604
        "id": "6s0ThFOJyx5l"
2605
      },
2606
      "outputs": [],
2607
      "source": [
2608
        "fig = plt.figure()\n",
2609
        "ax = fig.add_subplot(projection='3d')\n",
2610
        "\n",
2611
        "ax.scatter(surface_pts_subsampled[..., 0],\n",
2612
        "           surface_pts_subsampled[..., 1],\n",
2613
        "           surface_pts_subsampled[..., 2])\n"
2614
      ]
2615
    },
2616
    {
2617
      "cell_type": "code",
2618
      "execution_count": null,
2619
      "metadata": {
2620
        "id": "PUELUiCD5Oxu"
2621
      },
2622
      "outputs": [],
2623
      "source": [
2624
        "with open('{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'wb') as f:\n",
2625
        "  np.save(f, surface_pts_subsampled)\n"
2626
      ]
2627
    },
2628
    {
2629
      "cell_type": "code",
2630
      "execution_count": null,
2631
      "metadata": {
2632
        "id": "gSahd5OS0hws"
2633
      },
2634
      "outputs": [],
2635
      "source": [
2636
        "plt.plot(surface_pts_subsampled[..., 0], surface_pts_subsampled[..., 1], '.')"
2637
      ]
2638
    },
2639
    {
2640
      "cell_type": "code",
2641
      "execution_count": null,
2642
      "metadata": {
2643
        "id": "OrwX1dfmB45A"
2644
      },
2645
      "outputs": [],
2646
      "source": [
2647
        "surface_pts_subsampled.shape"
2648
      ]
2649
    },
2650
    {
2651
      "cell_type": "markdown",
2652
      "metadata": {
2653
        "id": "oYslVMJlH0Zs"
2654
      },
2655
      "source": [
2656
        "# Cache visibility using MLP\n",
2657
        "\n",
2658
        "Optimize an MLP mapping from position and direction to visibility:\n",
2659
        "$$(\\mathbf{x}, \\boldsymbol{\\omega}) \\mapsto v$$\n",
2660
        "where $v$ is a scalar visibility in $[0, 1]$ (constrained by a sigmoid), and the position and direction are 3-vectors with positional encoding."
2661
      ]
2662
    },
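    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sketch only (not the training cells below): the mapping described\n",
        "# above, (x, omega) -\u003e v, as a tiny flax MLP with a sigmoid output. The cells\n",
        "# further down add positional encoding and fit to occlusion_masks; `TinyVisMLP`\n",
        "# and its layer sizes here are made-up names for illustration.\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flax.linen as nn\n",
        "\n",
        "class TinyVisMLP(nn.Module):\n",
        "  @nn.compact\n",
        "  def __call__(self, x, omega):\n",
        "    h = jnp.concatenate([x, omega], axis=-1)\n",
        "    for _ in range(2):\n",
        "      h = nn.relu(nn.Dense(64)(h))\n",
        "    return jax.nn.sigmoid(nn.Dense(1)(h))  # visibility v in [0, 1]\n",
        "\n",
        "toy = TinyVisMLP()\n",
        "toy_params = toy.init(jax.random.PRNGKey(0), jnp.zeros((1, 3)), jnp.zeros((1, 3)))\n",
        "print(toy.apply(toy_params, jnp.zeros((1, 3)), jnp.zeros((1, 3))).shape)  # (1, 1)\n"
      ]
    },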
2663
    {
2664
      "cell_type": "code",
2665
      "execution_count": null,
2666
      "metadata": {
2667
        "id": "SCMrkNMuzyaM"
2668
      },
2669
      "outputs": [],
2670
      "source": [
2671
        "#with open('{DIRECTORY}/hotdog_surface_pts.npy', 'rb') as f:\n",
2672
        "#  surface_pts = np.load(f)\n"
2673
      ]
2674
    },
2675
    {
2676
      "cell_type": "code",
2677
      "execution_count": null,
2678
      "metadata": {
2679
        "id": "7CCAtFgjxwI2"
2680
      },
2681
      "outputs": [],
2682
      "source": []
2683
    },
2684
    {
2685
      "cell_type": "code",
2686
      "execution_count": null,
2687
      "metadata": {
2688
        "id": "b2MJEDxTxnfX"
2689
      },
2690
      "outputs": [],
2691
      "source": [
2692
        "surface_normals_subsampled.shape, omega_xyz.shape, occlusion_masks.shape"
2693
      ]
2694
    },
2695
    {
2696
      "cell_type": "code",
2697
      "execution_count": null,
2698
      "metadata": {
2699
        "id": "rwOrxx5Qbnqj"
2700
      },
2701
      "outputs": [],
2702
      "source": [
2703
        "with open('{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'rb') as f:\n",
2704
        "  surface_pts_subsampled = np.load(f)\n",
2705
        "\n",
2706
        "with open('{DIRECTORY}/subsampling_indices.npy', 'rb') as f:\n",
2707
        "  indices = np.load(f)\n",
2708
        "\n",
2709
        "with open('{DIRECTORY}/visibility_images.npy', 'rb') as f:\n",
2710
        "  occlusion_masks = jnp.float32(np.load(f)) / 255.0\n",
2711
        "\n",
2712
        "with open('{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'rb') as f:\n",
2713
        "  surface_normals_subsampled = np.load(f)\n",
2714
        "\n",
2715
        "envmap_H, envmap_W = occlusion_masks.shape[1:]\n",
2716
        "\n",
2717
        "omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),\n",
2718
        "                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1] +       jnp.pi / (2.0 * envmap_H))\n",
2719
        "\n",
2720
        "dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])\n",
2721
        "\n",
2722
        "omega_theta = omega_theta.flatten()\n",
2723
        "omega_phi = omega_phi.flatten()\n",
2724
        "\n",
2725
        "omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)\n",
2726
        "omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)\n",
2727
        "omega_z = jnp.cos(omega_theta)\n",
2728
        "omega_xyz = jnp.stack([omega_x,\n",
2729
        "                       omega_y,\n",
2730
        "                       omega_z], axis=-1)\n",
2731
        "\n",
2732
        "\n",
2733
        "# Turn the negative hemisphere into nans\n",
2734
        "occlusion_masks = jnp.where(jnp.sum(surface_normals_subsampled[:, None, :] * omega_xyz[None, :, :], axis=-1).reshape(-1, envmap_H, envmap_W) \u003e 0.0,\n",
2735
        "                            occlusion_masks, jnp.nan)\n",
2736
        "#                            occlusion_masks, 0.0)\n"
2737
      ]
2738
    },
2739
    {
2740
      "cell_type": "code",
2741
      "execution_count": null,
2742
      "metadata": {
2743
        "id": "vhGs5Ip9x7W-"
2744
      },
2745
      "outputs": [],
2746
      "source": [
2747
        "plt.imshow(occlusion_masks[70])"
2748
      ]
2749
    },
2750
    {
2751
      "cell_type": "code",
2752
      "execution_count": null,
2753
      "metadata": {
2754
        "id": "C5YOTkoNR9mc"
2755
      },
2756
      "outputs": [],
2757
      "source": []
2758
    },
2759
    {
2760
      "cell_type": "code",
2761
      "execution_count": null,
2762
      "metadata": {
2763
        "id": "dZ8y1LMdB5cM"
2764
      },
2765
      "outputs": [],
2766
      "source": [
2767
        "class MLP(nn.Module):\n",
2768
        "  features: Sequence[int]\n",
2769
        "\n",
2770
        "  @nn.compact\n",
2771
        "  def __call__(self, x, y=None):\n",
2772
        "    if y is not None:\n",
2773
        "      x = jnp.concatenate([x, y], axis=-1)\n",
2774
        "    for feat in self.features[:-1]:\n",
2775
        "      x = nn.relu(nn.Dense(feat)(x))\n",
2776
        "    x = nn.Dense(self.features[-1])(x)\n",
2777
        "    return x\n",
2778
        "\n",
2779
        "\n",
2780
        "append_identity = True\n",
2781
        "def posenc(x, L_encoding):\n",
2782
        "  if L_encoding \u003c= 0:\n",
2783
        "    return x\n",
2784
        "  else:\n",
2785
        "    scales = 2**jnp.arange(L_encoding)\n",
2786
        "    #shape = x.shape[:-1] + (-1,)\n",
2787
        "    #scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape)\n",
2788
        "\n",
2789
        "    #four_feat = jnp.sin(\n",
2790
        "    #    jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))\n",
2791
        "    shape = x.shape[:-1] + (-1,)\n",
2792
        "    scaled_x = x[..., None, :] * scales[:, None] # [..., L, D]\n",
2793
        "\n",
2794
        "    four_feat = jnp.sin(\n",
2795
        "        jnp.stack([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1)) # [..., L, D, 2]\n",
2796
        "\n",
2797
        "    #four_feat = jnp.reshape(four_feat / scales[:, None, None], shape)\n",
2798
        "    #print(\"Using Lipschitz posenc\")\n",
2799
        "    four_feat = jnp.reshape(four_feat, shape)\n",
2800
        "    if append_identity:\n",
2801
        "      return jnp.concatenate([x] + [four_feat], axis=-1)\n",
2802
        "    else:\n",
2803
        "      return four_feat\n",
2804
        "\n",
2805
        "\n",
2806
        "# Initialize material MLP\n",
2807
        "L_encoding_vis_x = 3\n",
2808
        "L_encoding_vis_dir = 6\n",
2809
        "mlp_input_features = 6 + 6 * (L_encoding_vis_x + L_encoding_vis_dir)\n",
2810
        "\n",
2811
        "num_components = 1\n",
2812
        "mlp_vis = MLP([128]*4 + [num_components])\n",
2813
        "\n",
2814
        "params_vis = mlp_vis.init(jax.random.PRNGKey(0),\n",
2815
        "                          np.zeros([1, mlp_input_features]))\n",
2816
        "\n",
2817
        "init_lr_vis = 1e-2\n",
2818
        "\n",
2819
        "init_vis, update_vis, get_params_vis = jax.experimental.optimizers.adam(init_lr_vis)\n",
2820
        "state_vis = init_vis(params_vis)\n",
2821
        "\n",
2822
        "#x_input = surface_pts_subsampled.reshape(-1, 1, 3).repeat(envmap_H * envmap_W, axis=1).reshape(-1, 3)\n",
2823
        "#dir_input = omega_xyz.reshape(1, -1, 3).repeat(surface_pts_subsampled.shape[0], axis=0).reshape(-1, 3)\n",
2824
        "#vis_gt = jnp.zeros(())\n",
2825
        "\n",
2826
        "\n",
2827
        "def get_vis_loss(params, x_input, dir_input, vis_gt):\n",
2828
        "\n",
2829
        "  vis_pred = jax.nn.sigmoid(mlp_vis.apply(params,\n",
2830
        "                                          posenc(x_input, L_encoding_vis_x),\n",
2831
        "                                          posenc(dir_input, L_encoding_vis_dir)))\n",
2832
        "\n",
2833
        "  loss = jnp.nansum((vis_pred - vis_gt) ** 2) / x_batch_size / envmap_H / envmap_W\n",
2834
        "\n",
2835
        "  return loss\n",
2836
        "\n",
2837
        "@jax.jit\n",
2838
        "def step(state, x_input, dir_input, vis_gt, i):\n",
2839
        "  params = get_params_vis(state)\n",
2840
        "  loss, grad = jax.value_and_grad(get_vis_loss)(params, x_input, dir_input, vis_gt)\n",
2841
        "\n",
2842
        "  return update_vis(i, grad, state), loss\n",
2843
        "\n",
2844
        "#get_vis_loss(state_vis)\n",
2845
        "\n",
2846
        "x_batch_size = 50\n",
2847
        "\n",
2848
        "losses = []\n",
2849
        "for i in range(500):\n",
2850
        "  image_indices = np.random.randint(0, surface_pts_subsampled.shape[0], size=(x_batch_size,))\n",
2851
        "  \n",
2852
        "  x_input = surface_pts_subsampled[image_indices].reshape(-1, 1, 3).repeat(envmap_H * envmap_W, axis=1).reshape(-1, 3)\n",
2853
        "  dir_input = omega_xyz.reshape(1, -1, 3).repeat(x_batch_size, axis=0).reshape(-1, 3)\n",
2854
        "  vis_gt = occlusion_masks[image_indices].reshape(-1, 1)\n",
2855
        "\n",
2856
        "  state_vis, loss = step(state_vis, x_input, dir_input, vis_gt, i)\n",
2857
        "\n",
2858
        "  losses.append(loss)\n",
2859
        "\n",
2860
        "plt.semilogy(losses)"
2861
      ]
2862
    },
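    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Added sanity check (not part of the original pipeline): with `append_identity = True`, `posenc` maps a `D`-dimensional input to `D + 2 * L * D` features, so the two encodings fed to the visibility MLP should together match `mlp_input_features = 6 + 6 * (L_encoding_vis_x + L_encoding_vis_dir)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch of a shape check for the encodings above (assumes posenc and the L_encoding_* constants are in scope).\n",
        "enc_x = posenc(np.zeros([1, 3]), L_encoding_vis_x)      # 3 + 2 * 3 * L_encoding_vis_x features\n",
        "enc_dir = posenc(np.zeros([1, 3]), L_encoding_vis_dir)  # 3 + 2 * 3 * L_encoding_vis_dir features\n",
        "print(enc_x.shape, enc_dir.shape)\n",
        "assert enc_x.shape[-1] + enc_dir.shape[-1] == mlp_input_features"
      ]
    },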
2863
    {
2864
      "cell_type": "code",
2865
      "execution_count": null,
2866
      "metadata": {
2867
        "id": "JtrQf09xzths"
2868
      },
2869
      "outputs": [],
2870
      "source": [
2871
        "jnp.nansum(occlusion_masks * 2 - 1)"
2872
      ]
2873
    },
2874
    {
2875
      "cell_type": "code",
2876
      "execution_count": null,
2877
      "metadata": {
2878
        "id": "idm3yLifaMRD"
2879
      },
2880
      "outputs": [],
2881
      "source": [
2882
        "@jax.jit\n",
2883
        "def evaluate_ind(params, x_input, dir_input, ind):\n",
2884
        "  vis_pred = jax.nn.sigmoid(mlp_vis.apply(params,\n",
2885
        "                                          posenc(x_input, L_encoding_vis_x),\n",
2886
        "                                          posenc(dir_input, L_encoding_vis_dir)))\n",
2887
        "\n",
2888
        "  return vis_pred\n",
2889
        "\n",
2890
        "def evaluate(params):\n",
2891
        "  dir_input = omega_xyz\n",
2892
        "\n",
2893
        "  total_error = 0.0\n",
2894
        "  for ind in range(surface_pts_subsampled.shape[0]):\n",
2895
        "    if (ind + 1) % 1000 == 0:\n",
2896
        "      print(ind + 1)\n",
2897
        "    x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)\n",
2898
        "    mask = evaluate_ind(params, x_input, dir_input, ind).reshape(-1)\n",
2899
        "    hard_mask = jnp.float32(mask \u003e 0.5)\n",
2900
        "    mask_gt = occlusion_masks[ind].reshape(-1)\n",
2901
        "\n",
2902
        "    error = jnp.nansum(jnp.abs(hard_mask - mask_gt))\n",
2903
        "    total_error += error\n",
2904
        "\n",
2905
        "  return total_error, total_error / surface_pts_subsampled.shape[0]\n",
2906
        "\n",
2907
        "ind = 70\n",
2908
        "\n",
2909
        "x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)\n",
2910
        "dir_input = omega_xyz\n",
2911
        "\n",
2912
        "mask = evaluate_ind(get_params_vis(state_vis), x_input, dir_input, ind).reshape(envmap_H, envmap_W)\n",
2913
        "plt.imshow(mask)\n",
2914
        "plt.figure()\n",
2915
        "plt.imshow(mask \u003e 0.5)\n",
2916
        "plt.figure()\n",
2917
        "plt.imshow(occlusion_masks[ind])\n",
2918
        "\n",
2919
        "#error, avg_error = evaluate(get_params_vis(state_vis))\n",
2920
        "\n",
2921
        "print(avg_error)"
2922
      ]
2923
    },
2924
    {
2925
      "cell_type": "code",
2926
      "execution_count": null,
2927
      "metadata": {
2928
        "id": "tSf39C_s0bSn"
2929
      },
2930
      "outputs": [],
2931
      "source": [
2932
        "indices = []\n",
2933
        "for ind in range(10000):\n",
2934
        "  d = ((surface_pts - surface_pts_subsampled[ind]) ** 2).sum(-1)\n",
2935
        "  i, r, c = np.unravel_index(np.argmin(d), shape=surface_pts.shape[:3])\n",
2936
        "  indices.append((i, r, c))\n",
2937
        "\n",
2938
        "with open('{DIRECTORY}/subsampling_indices.npy', 'wb') as f:\n",
2939
        "  np.save(f, np.array(indices))\n"
2940
      ]
2941
    },
2942
    {
2943
      "cell_type": "code",
2944
      "execution_count": null,
2945
      "metadata": {
2946
        "id": "LDxFU8Yy01_N"
2947
      },
2948
      "outputs": [],
2949
      "source": [
2950
        "#with open('{DIRECTORY}/subsampling_indices.npy', 'wb') as f:\n",
2951
        "#  np.save(f, np.array(indices))\n"
2952
      ]
2953
    },
2954
    {
2955
      "cell_type": "code",
2956
      "execution_count": null,
2957
      "metadata": {
2958
        "id": "QtHRjceE6cvi"
2959
      },
2960
      "outputs": [],
2961
      "source": [
2962
        "with open('{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'wb') as f:\n",
2963
        "  np.save(f, surface_normals_subsampled)\n"
2964
      ]
2965
    },
2966
    {
2967
      "cell_type": "code",
2968
      "execution_count": null,
2969
      "metadata": {
2970
        "id": "5LCuTu1-mCnm"
2971
      },
2972
      "outputs": [],
2973
      "source": [
2974
        "normals_gt.shape, indices.shape"
2975
      ]
2976
    },
2977
    {
2978
      "cell_type": "code",
2979
      "execution_count": null,
2980
      "metadata": {
2981
        "id": "lmZ_tVAivWpW"
2982
      },
2983
      "outputs": [],
2984
      "source": [
2985
        "surface_normals_subsampled = normals_gt.reshape(-1, 128, 128, 3)[indices[:, 0], indices[:, 1], indices[:, 2], :]\n",
2986
        "surface_normals_subsampled.shape"
2987
      ]
2988
    },
2989
    {
2990
      "cell_type": "code",
2991
      "execution_count": null,
2992
      "metadata": {
2993
        "id": "iVk2YZZhvg5F"
2994
      },
2995
      "outputs": [],
2996
      "source": [
2997
        "for ind in [0, 10, 20, 50, 100]:\n",
2998
        "#ind = 10\n",
2999
        "  a = jnp.where(jnp.sum(surface_normals_subsampled[ind] * omega_xyz.reshape(envmap_H, envmap_W, 3), axis=-1) \u003e 0.0, occlusion_masks[ind], 0.0)\n",
3000
        "  \n",
3001
        "  plt.figure(); plt.imshow(a)"
3002
      ]
3003
    },
3004
    {
3005
      "cell_type": "code",
3006
      "execution_count": null,
3007
      "metadata": {
3008
        "id": "U29UKCZAwxC2"
3009
      },
3010
      "outputs": [],
3011
      "source": [
3012
        "indices[100]"
3013
      ]
3014
    },
3015
    {
3016
      "cell_type": "code",
3017
      "execution_count": null,
3018
      "metadata": {
3019
        "id": "MLrDOPsrwrVt"
3020
      },
3021
      "outputs": [],
3022
      "source": [
3023
        "surface_normals_subsampled[109]"
3024
      ]
3025
    },
3026
    {
3027
      "cell_type": "code",
3028
      "execution_count": null,
3029
      "metadata": {
3030
        "id": "tY84xcMswyGI"
3031
      },
3032
      "outputs": [],
3033
      "source": [
3034
        "normals_gt.reshape(512, 128, 128, 3)[327, 96, 49, :]"
3035
      ]
3036
    },
3037
    {
3038
      "cell_type": "code",
3039
      "execution_count": null,
3040
      "metadata": {
3041
        "id": "Ytescsnhvsag"
3042
      },
3043
      "outputs": [],
3044
      "source": [
3045
        "plt.imshow(occlusion_masks[ind])"
3046
      ]
3047
    },
3048
    {
3049
      "cell_type": "code",
3050
      "execution_count": null,
3051
      "metadata": {
3052
        "id": "HA2YrqMYwHFP"
3053
      },
3054
      "outputs": [],
3055
      "source": [
3056
        "# A simple script that uses blender to render views of a single object by rotation the camera around it.\n",
3057
        "# Also produces depth map at the same time.\n",
3058
        "\n",
3059
        "import argparse, sys, os\n",
3060
        "import json\n",
3061
        "import bpy\n",
3062
        "import mathutils\n",
3063
        "import numpy as np\n",
3064
        "        \n",
3065
        "def listify_matrix(matrix):\n",
3066
        "    matrix_list = []\n",
3067
        "    for row in matrix:\n",
3068
        "        matrix_list.append(list(row))\n",
3069
        "    return matrix_list\n",
3070
        "\n",
3071
        "def delistify_matrix(lst):\n",
3072
        "    mat = mathutils.Matrix()\n",
3073
        "    for i in range(4):\n",
3074
        "        for j in range(4):\n",
3075
        "            mat[i][j] = lst[i][j]\n",
3076
        "    return mat\n",
3077
        "\n",
3078
        "\n",
3079
        "DEBUG = False\n",
3080
        "envmap_H = 50\n",
3081
        "envmap_W = 99\n",
3082
        "FORMAT = 'PNG'\n",
3083
        "\n",
3084
        "# filename is /.../\u003cmodel\u003e.blend/.../\u003cscript\u003e.py\n",
3085
        "ind_f = __file__.find('.blend')\n",
3086
        "ind_i = __file__[:ind_f].rfind('/') + 1\n",
3087
        "#model_name = __file__[ind_i:ind_f] + '_uniform'\n",
3088
        "#model_name = 'hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'\n",
3089
        "model_name = 'hotdog_occlusions_lambertian_linear_128x128'\n",
3090
        "\n",
3091
        "# Read from file\n",
3092
        "#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',\n",
3093
        "#                    f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']\n",
3094
        "\n",
3095
        "partitions = ['train']\n",
3096
        "transforms_files = [f'/Users/dorverbin/Downloads/blend_files/{model_name}/transforms_{partition}.json' for partition in partitions]\n",
3097
        "RESULTS_PATH = os.path.join(model_name, 'visibility')\n",
3098
        "\n",
3099
        "fp = bpy.path.abspath(f\"//{RESULTS_PATH}\")\n",
3100
        "\n",
3101
        "if not os.path.exists(fp):\n",
3102
        "    os.makedirs(fp)\n",
3103
        "for partition in partitions:\n",
3104
        "    if not os.path.exists(os.path.join(fp, partition)):\n",
3105
        "        os.makedirs(os.path.join(fp, partition))\n",
3106
        "\n",
3107
        "# Data to store in JSON file\n",
3108
        "#out_data = {}\n",
3109
        "\n",
3110
        "\n",
3111
        "# Render Optimizations\n",
3112
        "bpy.context.scene.render.use_persistent_data = True\n",
3113
        "\n",
3114
        "\n",
3115
        "# Set up rendering of depth map.\n",
3116
        "bpy.context.scene.use_nodes = True\n",
3117
        "tree = bpy.context.scene.node_tree\n",
3118
        "links = tree.links\n",
3119
        "\n",
3120
        "# Add passes for additionally dumping albedo and normals.\n",
3121
        "bpy.context.scene.render.image_settings.file_format = str('PNG')\n",
3122
        "bpy.context.scene.render.image_settings.color_depth = str(8)\n",
3123
        "print(\"Only 32 if using EXR. When I use binary forget about it\")\n",
3124
        "\n",
3125
        "# If using OpenEXR, set to linear color space\n",
3126
        "bpy.data.scenes['Scene'].display_settings.display_device = 'None'\n",
3127
        "bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear'  \n",
3128
        "\n",
3129
        "# Remove all tree nodes\n",
3130
        "for node in tree.nodes:\n",
3131
        "    tree.nodes.remove(node)\n",
3132
        "\n",
3133
        "if 'Custom Outputs' not in tree.nodes:\n",
3134
        "    # Create input render layer node.\n",
3135
        "    render_layers = tree.nodes.new('CompositorNodeRLayers')\n",
3136
        "    render_layers.label = 'Custom Outputs'\n",
3137
        "    render_layers.name = 'Custom Outputs'\n",
3138
        "\n",
3139
        "# Background\n",
3140
        "bpy.context.scene.render.dither_intensity = 0.0\n",
3141
        "bpy.context.scene.render.film_transparent = False\n",
3142
        "\n",
3143
        "# Create collection for objects not to render with background\n",
3144
        "\n",
3145
        "\n",
3146
        "scene = bpy.context.scene\n",
3147
        "scene.render.resolution_x = envmap_W\n",
3148
        "scene.render.resolution_y = envmap_H\n",
3149
        "scene.render.resolution_percentage = 100\n",
3150
        "\n",
3151
        "cam = scene.objects['Camera']\n",
3152
        "\n",
3153
        "# Define equirect camera\n",
3154
        "cam.data.type = 'PANO'\n",
3155
        "cam.data.cycles.panorama_type = 'EQUIRECTANGULAR'\n",
3156
        "#cam.data.cycles.latitude_min = np.pi / (2.0 * envmap_H) - np.pi / 2.0\n",
3157
        "#cam.data.cycles.latitude_max = np.pi / 2.0 - np.pi / (2.0 * envmap_H)\n",
3158
        "\n",
3159
        "\n",
3160
        "#cam.location = (0, 4.0, 0.5)\n",
3161
        "\n",
3162
        "#cam_constraint = cam.constraints.new(type='TRACK_TO')\n",
3163
        "#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'\n",
3164
        "#cam_constraint.up_axis = 'UP_Y'\n",
3165
        "#b_empty = parent_obj_to_camera(cam)\n",
3166
        "#cam_constraint.target = b_empty\n",
3167
        "\n",
3168
        "\n",
3169
        "#scene.render.image_settings.file_format = 'PNG'  # set output format to .png\n",
3170
        "scene.render.image_settings.file_format = FORMAT  # set output format to .png\n",
3171
        "\n",
3172
        "canonical_mat = [[ 0.0, 0.0, -1.0, 0.0],\n",
3173
        "                 [ 1.0, 0.0,  0.0, 0.0],\n",
3174
        "                 [ 0.0, 1.0,  0.0, 0.0],\n",
3175
        "                 [ 0.0, 0.0,  0.0, 1.0]]\n",
3176
        "\n",
3177
        "#hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts.npy')\n",
3178
        "hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts_subsampled.npy')\n",
3179
        "\n",
3180
        "all_object_names_except_camera = [k for k in scene.objects.keys() if 'Camera' not in k]\n",
3181
        "\n",
3182
        "def toggle_object_visibility(do_hide_objects):\n",
3183
        "    for o in all_object_names_except_camera:\n",
3184
        "        scene.objects[o].hide_render = do_hide_objects\n",
3185
        "    \n",
3186
        "def toggle_mask_visibility(do_hide_mask):\n",
3187
        "    bpy.data.worlds[\"World\"].node_tree.nodes[\"Math\"].inputs[1].default_value = 1.01 if do_hide_mask else 0.8\n",
3188
        "\n",
3189
        "\n",
3190
        "for transforms_file, partition in zip(transforms_files, partitions):\n",
3191
        "    if transforms_file is not None:\n",
3192
        "        with open(transforms_file) as in_file:\n",
3193
        "            transforms_data = json.load(in_file)\n",
3194
        "        \n",
3195
        "        VIEWS = len(transforms_data['frames'])\n",
3196
        "    else:\n",
3197
        "        raise RuntimeError('Must specify transforms file')\n",
3198
        "\n",
3199
        "    #out_data['frames'] = []\n",
3200
        "\n",
3201
        "    #for i in range(0, VIEWS, 20):\n",
3202
        "    #if partition == 'train':\n",
3203
        "    #    continue\n",
3204
        "    \"\"\"\n",
3205
        "    for i in range(1):\n",
3206
        "        #cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix'])    \n",
3207
        "        #print(cam.matrix_world)\n",
3208
        "\n",
3209
        "        # Start by rendering mask\n",
3210
        "        cam.matrix_world = delistify_matrix(canonical_mat)\n",
3211
        "\n",
3212
        "        toggle_object_visibility(False)\n",
3213
        "        toggle_mask_visibility(True)\n",
3214
        "\n",
3215
        "\n",
3216
        "        for r in [40]:\n",
3217
        "            for c in range(30, 40):\n",
3218
        "                p = hotdog_points[i, r, c, :]\n",
3219
        "                if not np.all(np.isfinite(p)):\n",
3220
        "                    continue\n",
3221
        "                \n",
3222
        "                point_x, point_y, point_z = p\n",
3223
        "            \n",
3224
        "                cam.matrix_world[0][3] = point_x\n",
3225
        "                cam.matrix_world[1][3] = point_y\n",
3226
        "                cam.matrix_world[2][3] = point_z\n",
3227
        "                \n",
3228
        "                scene.render.filepath = os.path.join(fp, partition, f'r_{i}_{r}_{c}')          \n",
3229
        "                bpy.ops.render.render(write_still=True)  # render still\n",
3230
        "                \n",
3231
        "            \n",
3232
        "        toggle_object_visibility(True)\n",
3233
        "        toggle_mask_visibility(False)\n",
3234
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value\"].outputs[0].default_value = cam.matrix_world[0][3]\n",
3235
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.001\"].outputs[0].default_value = cam.matrix_world[1][3]\n",
3236
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.002\"].outputs[0].default_value = cam.matrix_world[2][3]                \n",
3237
        "        \n",
3238
        "        \n",
3239
        "        # Render mask\n",
3240
        "        scene.render.filepath = os.path.join(fp, partition, f'r_{i}_mask')\n",
3241
        "        if DEBUG:\n",
3242
        "            break\n",
3243
        "        else:\n",
3244
        "            bpy.ops.render.render(write_still=True)  # render still\n",
3245
        "    \"\"\"\n",
3246
        "\n",
3247
        "    toggle_object_visibility(False)\n",
3248
        "    toggle_mask_visibility(True)\n",
3249
        "\n",
3250
        "    #for ind in range(hotdog_points.shape[0]):\n",
3251
        "    for ind in [100]:    \n",
3252
        "        p = hotdog_points[ind, :]\n",
3253
        "        if not np.all(np.isfinite(p)):\n",
3254
        "            continue\n",
3255
        "        \n",
3256
        "        point_x, point_y, point_z = p\n",
3257
        "    \n",
3258
        "        cam.matrix_world[0][3] = point_x\n",
3259
        "        cam.matrix_world[1][3] = point_y\n",
3260
        "        cam.matrix_world[2][3] = point_z\n",
3261
        "        \n",
3262
        "        #scene.render.filepath = os.path.join(fp, partition, f'r_{ind}')          \n",
3263
        "        #bpy.ops.render.render(write_still=True)  # render still\n",
3264
        "        \n",
3265
        "\n",
3266
        "\n",
3267
        "\n",
3268
        "        #frame_data = {\n",
3269
        "        #    #'file_path': scene.render.filepath,\n",
3270
        "        #    'file_path': f'./{partition}/r_{i}',\n",
3271
        "        #    'transform_matrix': listify_matrix(cam.matrix_world)\n",
3272
        "        #}\n",
3273
        "        #out_data['frames'].append(frame_data)\n",
3274
        "\n",
3275
        "        #if transforms_file is None:\n",
3276
        "        #    b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3277
        "        #    b_empty.rotation_euler[2] += radians(2*stepsize)\n",
3278
        "\n",
3279
        "    #if not DEBUG:\n",
3280
        "    #    with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:\n",
3281
        "    #        json.dump(out_data, out_file, indent=4)\n",
3282
        "\n",
3283
        "\n"
3284
      ]
3285
    },
3286
    {
3287
      "cell_type": "code",
3288
      "execution_count": null,
3289
      "metadata": {
3290
        "id": "DvcjVwuz-Q1b"
3291
      },
3292
      "outputs": [],
3293
      "source": [
3294
        "# A simple script that uses blender to render views of a single object by rotation the camera around it.\n",
3295
        "# Also produces depth map at the same time.\n",
3296
        "\n",
3297
        "import argparse, sys, os\n",
3298
        "import json\n",
3299
        "import bpy\n",
3300
        "import mathutils\n",
3301
        "import numpy as np\n",
3302
        "        \n",
3303
        "\n",
3304
        "def listify_matrix(matrix):\n",
3305
        "    matrix_list = []\n",
3306
        "    for row in matrix:\n",
3307
        "        matrix_list.append(list(row))\n",
3308
        "    return matrix_list\n",
3309
        "\n",
3310
        "def delistify_matrix(lst):\n",
3311
        "    mat = mathutils.Matrix()\n",
3312
        "    for i in range(4):\n",
3313
        "        for j in range(4):\n",
3314
        "            mat[i][j] = lst[i][j]\n",
3315
        "    return mat\n",
3316
        "      \n",
3317
        "def parse_bin(s):\n",
3318
        "  return int(s[1:], 2) / 2.**(len(s) - 1)\n",
3319
        "\n",
3320
        "\n",
3321
        "def phi2(i):\n",
3322
        "  return parse_bin('.' + f'{i:b}'[::-1])\n",
3323
        "\n",
3324
        "def nice_uniform(N):\n",
3325
        "  u = []\n",
3326
        "  v = []\n",
3327
        "  for i in range(N):\n",
3328
        "    u.append(i / float(N))\n",
3329
        "    v.append(phi2(i))\n",
3330
        "    #pts.append((i/float(N), phi2(i)))\n",
3331
        "\n",
3332
        "  return u, v\n",
3333
        "\n",
3334
        "def nice_uniform_spherical(N, hemisphere=True):\n",
3335
        "  \"\"\"implementation of http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html\"\"\"\n",
3336
        "  u, v = nice_uniform(N)\n",
3337
        "\n",
3338
        "  theta = np.arccos(1.0 - np.array(u)) * (2.0 - int(hemisphere))\n",
3339
        "  phi   = 2.0 * np.pi * np.array(v)\n",
3340
        "\n",
3341
        "  return theta, phi\n",
3342
        "    \n",
3343
        "    \n",
3344
        "    \n",
3345
        "hemisphere = True\n",
3346
        "camera_dist = np.sqrt(4.0**2 + 0.5**2)\n",
3347
        "def get_all_camera_matrices(N_cameras, camera_dist=camera_dist):\n",
3348
        "  theta, phi = nice_uniform_spherical(N_cameras, hemisphere)\n",
3349
        "\n",
3350
        "  camera_x_vec = np.sin(theta) * np.cos(phi)\n",
3351
        "  camera_y_vec = np.sin(theta) * np.sin(phi)\n",
3352
        "  camera_z_vec = np.cos(theta)\n",
3353
        "\n",
3354
        "  cameras = []\n",
3355
        "  for i in range(N_cameras):\n",
3356
        "    camera = np.eye(4)\n",
3357
        "    camera[0, 3] = camera_x_vec[i] * camera_dist\n",
3358
        "    camera[1, 3] = camera_y_vec[i] * camera_dist\n",
3359
        "    camera[2, 3] = camera_z_vec[i] * camera_dist\n",
3360
        "\n",
3361
        "    zdir = np.array([camera_x_vec[i], camera_y_vec[i], camera_z_vec[i]])\n",
3362
        "    zdir /= np.linalg.norm(zdir)\n",
3363
        "\n",
3364
        "    ydir = np.array([0.0, 0.0, 1.0])\n",
3365
        "    ydir -= zdir * zdir.dot(ydir)\n",
3366
        "    ydir[0] += 1e-10  # make sure that cameras pointing straight down/up have a defined ydir\n",
3367
        "    ydir /= np.linalg.norm(ydir)\n",
3368
        "\n",
3369
        "    xdir = np.cross(ydir, zdir)\n",
3370
        "\n",
3371
        "    camera[:3, 0] = xdir\n",
3372
        "    camera[:3, 1] = ydir\n",
3373
        "    camera[:3, 2] = zdir\n",
3374
        "    \n",
3375
        "    cameras.append(camera)\n",
3376
        "  return cameras\n",
3377
        "         \n",
3378
        "DEBUG = False\n",
3379
        "VIEWS = 512  # Only used if not specifying transforms_file\n",
3380
        "RESOLUTION = 128 #800\n",
3381
        "DEPTH_SCALE = 1.4\n",
3382
        "FORMAT = 'OPEN_EXR'\n",
3383
        "COLOR_DEPTH = 8 if FORMAT == 'PNG' else 32\n",
3384
        " \n",
3385
        "       \n",
3386
        "# filename is /.../\u003cmodel\u003e.blend/.../\u003cscript\u003e.py\n",
3387
        "ind_f = __file__.find('.blend')\n",
3388
        "ind_i = __file__[:ind_f].rfind('/') + 1\n",
3389
        "model_name = __file__[ind_i:ind_f] + '_uniform'\n",
3390
        "\n",
3391
        "\n",
3392
        "if FORMAT == 'OPEN_EXR':\n",
3393
        "    RESULTS_PATH = f'{model_name}_linear'\n",
3394
        "elif FORMAT == 'PNG':\n",
3395
        "    RESULTS_PATH = model_name\n",
3396
        "else:\n",
3397
        "    raise RuntimeError('format unknown')\n",
3398
        "\n",
3399
        "if RESOLUTION != 800:\n",
3400
        "    RESULTS_PATH += f'_{RESOLUTION}x{RESOLUTION}'\n",
3401
        "\n",
3402
        "# Read from file\n",
3403
        "#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',\n",
3404
        "#                    f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']\n",
3405
        "#partitions = ['train', 'test']\n",
3406
        "transforms_files = [None]\n",
3407
        "partitions = ['occlusions']\n",
3408
        "\n",
3409
        "\n",
3410
        "fp = bpy.path.abspath(f\"//{RESULTS_PATH}\")\n",
3411
        "\n",
3412
        "if not os.path.exists(fp):\n",
3413
        "    os.makedirs(fp)\n",
3414
        "for partition in partitions:\n",
3415
        "    if not os.path.exists(os.path.join(fp, partition)):\n",
3416
        "        os.makedirs(os.path.join(fp, partition))\n",
3417
        "\n",
3418
        "# Data to store in JSON file\n",
3419
        "out_data = {\n",
3420
        "    'camera_angle_x': bpy.data.objects['Camera'].data.angle_x,\n",
3421
        "}\n",
3422
        "\n",
3423
        "# Render Optimizations\n",
3424
        "bpy.context.scene.render.use_persistent_data = True\n",
3425
        "\n",
3426
        "\n",
3427
        "# Set up rendering of depth map.\n",
3428
        "bpy.context.scene.use_nodes = True\n",
3429
        "tree = bpy.context.scene.node_tree\n",
3430
        "links = tree.links\n",
3431
        "\n",
3432
        "# Add passes for additionally dumping albedo and normals.\n",
3433
        "bpy.context.scene.view_layers[\"RenderLayer\"].use_pass_normal = True\n",
3434
        "bpy.context.scene.render.image_settings.file_format = str(FORMAT)\n",
3435
        "bpy.context.scene.render.image_settings.color_depth = str(COLOR_DEPTH)\n",
3436
        "\n",
3437
        "# If using OpenEXR, set to linear color space\n",
3438
        "if FORMAT == 'OPEN_EXR':\n",
3439
        "    bpy.data.scenes['Scene'].display_settings.display_device = 'None'\n",
3440
        "    bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear'\n",
3441
        "else:\n",
3442
        "    bpy.data.scenes['Scene'].display_settings.display_device = 'sRGB'\n",
3443
        "    bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'sRGB'    \n",
3444
        "\n",
3445
        "# Remove all tree nodes\n",
3446
        "for node in tree.nodes:\n",
3447
        "    tree.nodes.remove(node)\n",
3448
        "\n",
3449
        "if 'Custom Outputs' not in tree.nodes:\n",
3450
        "    # Create input render layer node.\n",
3451
        "    render_layers = tree.nodes.new('CompositorNodeRLayers')\n",
3452
        "    render_layers.label = 'Custom Outputs'\n",
3453
        "    render_layers.name = 'Custom Outputs'\n",
3454
        "    \n",
3455
        "    depth_file_output = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n",
3456
        "    depth_file_output.label = 'Depth Output'\n",
3457
        "    depth_file_output.name = 'Depth Output'\n",
3458
        "    if FORMAT == 'OPEN_EXR':\n",
3459
        "      add_one = tree.nodes.new('CompositorNodeMath')\n",
3460
        "      add_one.operation = 'ADD'\n",
3461
        "      add_one.inputs[1].default_value = 1.0\n",
3462
        "      links.new(render_layers.outputs['Depth'], add_one.inputs[0])\n",
3463
        "      \n",
3464
        "      recip = tree.nodes.new('CompositorNodeMath')\n",
3465
        "      recip.operation = 'DIVIDE'\n",
3466
        "      recip.inputs[0].default_value = 1.0\n",
3467
        "      links.new(add_one.outputs[0], recip.inputs[1])\n",
3468
        "      \n",
3469
        "      links.new(recip.outputs[0], depth_file_output.inputs[0])\n",
3470
        "      \n",
3471
        "    else:\n",
3472
        "      # Remap as other types can not represent the full range of depth.\n",
3473
        "      map = tree.nodes.new(type=\"CompositorNodeMapRange\")\n",
3474
        "      # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map.\n",
3475
        "      map.inputs['From Min'].default_value = 0\n",
3476
        "      map.inputs['From Max'].default_value = 8\n",
3477
        "      map.inputs['To Min'].default_value = 1\n",
3478
        "      map.inputs['To Max'].default_value = 0\n",
3479
        "      links.new(render_layers.outputs['Depth'], map.inputs[0])\n",
3480
        "\n",
3481
        "      links.new(map.outputs[0], depth_file_output.inputs[0])\n",
3482
        "    \n",
3483
        "    normal_file_output = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n",
3484
        "    normal_file_output.label = 'Normal Output'\n",
3485
        "    normal_file_output.name = 'Normal Output'\n",
3486
        "    normal_file_output.format.file_format = 'PNG'\n",
3487
        "    \n",
3488
        "    # Separate normals into channels, transform (x+1)/2 and combine\n",
3489
        "    sep_rgba = tree.nodes.new('CompositorNodeSepRGBA')\n",
3490
        "    links.new(render_layers.outputs['Normal'], sep_rgba.inputs[0])\n",
3491
        "    \n",
3492
        "    comb_rgba = tree.nodes.new('CompositorNodeCombRGBA')\n",
3493
        "    add_ones = []\n",
3494
        "    divide_by_twos = []\n",
3495
        "    for i in range(3):\n",
3496
        "      add_ones.append(tree.nodes.new('CompositorNodeMath'))\n",
3497
        "      add_ones[i].operation = 'ADD'\n",
3498
        "      add_ones[i].inputs[1].default_value = 1.0\n",
3499
        "      links.new(sep_rgba.outputs[i], add_ones[i].inputs[0])        \n",
3500
        "    \n",
3501
        "      divide_by_twos.append(tree.nodes.new('CompositorNodeMath'))\n",
3502
        "      divide_by_twos[i].operation = 'DIVIDE'\n",
3503
        "      divide_by_twos[i].inputs[1].default_value = 2.0\n",
3504
        "      links.new(add_ones[i].outputs[0], divide_by_twos[i].inputs[0])        \n",
3505
        "        \n",
3506
        "      links.new(divide_by_twos[i].outputs[0], comb_rgba.inputs[i])\n",
3507
        "      \n",
3508
        "    # Connect alpha\n",
3509
        "    links.new(sep_rgba.outputs[3], comb_rgba.inputs[3])\n",
3510
        "    \n",
3511
        "    links.new(comb_rgba.outputs[0], normal_file_output.inputs[0])\n",
3512
        "\n",
3513
        "# Background\n",
3514
        "bpy.context.scene.render.dither_intensity = 0.0\n",
3515
        "bpy.context.scene.render.film_transparent = True\n",
3516
        "\n",
3517
        "# Create collection for objects not to render with background\n",
3518
        "\n",
3519
        "    \n",
3520
        "objs = [ob for ob in bpy.context.scene.objects if ob.type in ('EMPTY') and 'Empty' in ob.name]\n",
3521
        "bpy.ops.object.delete({\"selected_objects\": objs})\n",
3522
        "\n",
3523
        "def parent_obj_to_camera(b_camera):\n",
3524
        "    origin = (0, 0, 0)\n",
3525
        "    b_empty = bpy.data.objects.new(\"Empty\", None)\n",
3526
        "    b_empty.location = origin\n",
3527
        "    b_camera.parent = b_empty  # setup parenting\n",
3528
        "\n",
3529
        "    scn = bpy.context.scene\n",
3530
        "    scn.collection.objects.link(b_empty)\n",
3531
        "    bpy.context.view_layer.objects.active = b_empty\n",
3532
        "    # scn.objects.active = b_empty\n",
3533
        "    return b_empty\n",
3534
        "\n",
3535
        "\n",
3536
        "scene = bpy.context.scene\n",
3537
        "scene.render.resolution_x = RESOLUTION\n",
3538
        "scene.render.resolution_y = RESOLUTION\n",
3539
        "scene.render.resolution_percentage = 100\n",
3540
        "\n",
3541
        "cam = scene.objects['Camera']\n",
3542
        "#cam.location = (0, 4.0, 0.5)\n",
3543
        "\n",
3544
        "#cam_constraint = cam.constraints.new(type='TRACK_TO')\n",
3545
        "#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'\n",
3546
        "#cam_constraint.up_axis = 'UP_Y'\n",
3547
        "#b_empty = parent_obj_to_camera(cam)\n",
3548
        "#cam_constraint.target = b_empty\n",
3549
        "\n",
3550
        "\n",
3551
        "#scene.render.image_settings.file_format = 'PNG'  # set output format to .png\n",
3552
        "scene.render.image_settings.file_format = FORMAT  # set output format to .png\n",
3553
        "\n",
3554
        "from math import radians\n",
3555
        "\n",
3556
        "stepsize = 360.0 / VIEWS\n",
3557
        "rotation_mode = 'XYZ'\n",
3558
        "\n",
3559
        "\n",
3560
        "\n",
3561
        "\n",
3562
        "if not DEBUG:\n",
3563
        "    for output_node in [tree.nodes['Depth Output'], tree.nodes['Normal Output']]:\n",
3564
        "        output_node.base_path = ''\n",
3565
        "\n",
3566
        "for transforms_file, partition in zip(transforms_files, partitions):\n",
3567
        "    if transforms_file is not None:\n",
3568
        "        with open(transforms_file) as in_file:\n",
3569
        "            transforms_data = json.load(in_file)\n",
3570
        "        \n",
3571
        "        VIEWS = len(transforms_data['frames'])\n",
3572
        "    else:\n",
3573
        "        cameras = get_all_camera_matrices(VIEWS)\n",
3574
        "\n",
3575
        "    out_data['frames'] = []\n",
3576
        "\n",
3577
        "    #for i in range(0, VIEWS, 20):\n",
3578
        "    #if partition == 'train':\n",
3579
        "    #    continue\n",
3580
        "    print(VIEWS)\n",
3581
        "    for i in [350]:\n",
3582
        "        if transforms_file is None:\n",
3583
        "            for a in range(4):\n",
3584
        "                for b in range(4):\n",
3585
        "                    cam.matrix_world[a][b] = cameras[i][a][b]\n",
3586
        "            print(cameras[i], i)\n",
3587
        "            #cam.location.x = cameras[i][0][3]\n",
3588
        "            #cam.location.y = cameras[i][1][3]\n",
3589
        "            #cam.location.z = cameras[i][2][3]\n",
3590
        "            #if RANDOM_VIEWS:\n",
3591
        "            #    scene.render.filepath = fp + '/r_' + str(i)\n",
3592
        "            #    b_empty.rotation_euler = np.random.uniform(0, 2*np.pi, size=3)\n",
3593
        "            #else:\n",
3594
        "            #    print(\"Rotation {}, {}\".format((stepsize * i), radians(stepsize * i)))\n",
3595
        "            #    scene.render.filepath = fp + '/r_{0:03d}'.format(int(i * stepsize))\n",
3596
        "        else:\n",
3597
        "            cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix'])    \n",
3598
        "            print(cam.matrix_world)\n",
3599
        "            #peter.matrix_world = cam.matrix_world\n",
3600
        "        if DEBUG:\n",
3601
        "            i = np.random.randint(0,VIEWS)\n",
3602
        "            b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3603
        "            b_empty.rotation_euler[2] += radians(2*stepsize*i)\n",
3604
        "       \n",
3605
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value\"].outputs[0].default_value = cam.matrix_world[0][3]\n",
3606
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.001\"].outputs[0].default_value = cam.matrix_world[1][3]\n",
3607
        "        bpy.data.worlds[\"World\"].node_tree.nodes[\"Value.002\"].outputs[0].default_value = cam.matrix_world[2][3]                \n",
3608
        "        \n",
3609
        "        print(\"Rotation {}, {}\".format((stepsize * i), radians(stepsize * i)))\n",
3610
        "        scene.render.filepath = os.path.join(fp, partition, f'r_{i}')\n",
3611
        "\n",
3612
        "        tree.nodes['Depth Output'].file_slots[0].path = scene.render.filepath + \"_disp_\"\n",
3613
        "        tree.nodes['Normal Output'].file_slots[0].path = scene.render.filepath + \"_normal_\"\n",
3614
        "\n",
3615
        "        break\n",
3616
        "        if DEBUG:\n",
3617
        "            break\n",
3618
        "        else:\n",
3619
        "            bpy.ops.render.render(write_still=True)  # render still\n",
3620
        "\n",
3621
        "        frame_data = {\n",
3622
        "            #'file_path': scene.render.filepath,\n",
3623
        "            'file_path': f'./{partition}/r_{i}',\n",
3624
        "            'rotation': radians(stepsize),\n",
3625
        "            'transform_matrix': listify_matrix(cam.matrix_world)\n",
3626
        "        }\n",
3627
        "        out_data['frames'].append(frame_data)\n",
3628
        "\n",
3629
        "        #if transforms_file is None:\n",
3630
        "        #    b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff\n",
3631
        "        #    b_empty.rotation_euler[2] += radians(2*stepsize)\n",
3632
        "\n",
3633
        "    if not DEBUG:\n",
3634
        "        with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:\n",
3635
        "            json.dump(out_data, out_file, indent=4)\n"
3636
      ]
3637
    }
3638
  ],
3639
  "metadata": {
3640
    "colab": {
3641
      "collapsed_sections": [],
3642
      "last_runtime": {
3643
        "build_target": "//googlex/gcam/buff/mipnerf360:notebook",
3644
        "kind": "private"
3645
      },
3646
      "name": "spherical_peter_problem_blender.ipynb",
3647
      "private_outputs": true,
3648
      "provenance": [
3649
        {
3650
          "file_id": "1Zys9YweuFO3a-1IbJNqMEYFcMgcI5-kH",
3651
          "timestamp": 1658639749060
3652
        }
3653
      ]
3654
    },
3655
    "kernelspec": {
3656
      "display_name": "Python 3",
3657
      "name": "python3"
3658
    },
3659
    "language_info": {
3660
      "name": "python"
3661
    }
3662
  },
3663
  "nbformat": 4,
3664
  "nbformat_minor": 0
3665
}
3666
