google-research

Форк
0
/
interactive-triplane-colab.ipynb 
353 строки · 11.7 Кб
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": [],
7
      "gpuType": "T4"
8
    },
9
    "kernelspec": {
10
      "name": "python3",
11
      "display_name": "Python 3"
12
    },
13
    "language_info": {
14
      "name": "python"
15
    },
16
    "accelerator": "GPU",
17
    "gpuClass": "standard"
18
  },
19
  "cells": [
20
    {
21
      "cell_type": "markdown",
22
      "source": [
23
        "Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0\n",
24
        "\n",
25
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
26
        "\n",
27
        "https://www.apache.org/licenses/LICENSE-2.0\n",
28
        "\n",
29
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
30
      ],
31
      "metadata": {
32
        "id": "K8hG1oJ_lHmo"
33
      }
34
    },
35
    {
36
      "cell_type": "markdown",
37
      "source": [
38
        "# Persistent Nature: Interactive Widget - Triplane Model"
39
      ],
40
      "metadata": {
41
        "id": "Ybs-1iqxlVQc"
42
      }
43
    },
44
    {
45
      "cell_type": "markdown",
46
      "source": [
47
        "## Download Dependencies"
48
      ],
49
      "metadata": {
50
        "id": "8U5LxiVjlXWF"
51
      }
52
    },
53
    {
      "cell_type": "code",
      "source": [
        "# Install build/runtime dependencies.\n",
        "# %pip (rather than ! pip) guarantees the install targets the active kernel's environment.\n",
        "%pip install ninja --quiet\n",
        "%pip install git+https://github.com/davidbau/baukit --quiet\n",
        "%pip install git+https://github.com/openai/CLIP.git --quiet\n",
        "%pip install einops --quiet"
      ],
      "metadata": {
        "id": "Vsiw9sEFlYVZ"
      },
      "execution_count": null,
      "outputs": []
    },
67
    {
68
      "cell_type": "code",
69
      "source": [
70
        "# Fetch the google-research monorepo; the persistent-nature project lives inside it\n",
        "! git clone https://github.com/google-research/google-research.git"
71
      ],
72
      "metadata": {
73
        "id": "qkKADkRxlbzk"
74
      },
75
      "execution_count": null,
76
      "outputs": []
77
    },
78
    {
79
      "cell_type": "code",
80
      "source": [
81
        "# Enter the project directory so relative paths (patch.sh, download.sh, pretrained/) resolve\n",
        "%cd google-research/persistent-nature"
82
      ],
83
      "metadata": {
84
        "id": "Z3bRYPwGlfHT"
85
      },
86
      "execution_count": null,
87
      "outputs": []
88
    },
89
    {
90
      "cell_type": "code",
91
      "source": [
92
        "%%capture\n",
93
        "# Apply repo patches and download pretrained checkpoints; %%capture silences the verbose output\n",
        "! bash patch.sh\n",
94
        "! bash download.sh"
95
      ],
96
      "metadata": {
97
        "id": "1xv6dnpdlglc"
98
      },
99
      "execution_count": null,
100
      "outputs": []
101
    },
102
    {
103
      "cell_type": "code",
104
      "source": [
105
        "# Sanity check: confirm we are inside persistent-nature and assets were fetched\n",
        "! ls"
106
      ],
107
      "metadata": {
108
        "id": "e5IUMyQ5liTa"
109
      },
110
      "execution_count": null,
111
      "outputs": []
112
    },
113
    {
114
      "cell_type": "markdown",
115
      "source": [
116
        "## Setup models"
117
      ],
118
      "metadata": {
119
        "id": "FSUP-tZjlkQ3"
120
      }
121
    },
122
    {
123
      "cell_type": "code",
124
      "source": [
125
        "import torch\n",
126
        "from PIL import Image\n",
127
        "import pickle\n",
128
        "import numpy as np\n",
129
        "from IPython.display import display\n",
130
        "from ipywidgets import HTML, Button, HBox, VBox, Layout\n",
131
        "from baukit import renormalize\n",
132
        "from models.triplane import model_full\n",
133
        "from utils import sky_util, soat_util_triplane, camera_util, noise_util\n"
134
      ],
135
      "metadata": {
136
        "id": "qpY04qJglkj9"
137
      },
138
      "execution_count": null,
139
      "outputs": []
140
    },
141
    {
142
      "cell_type": "code",
143
      "source": [
144
        "# Inference-only notebook: disable autograd globally to save memory and compute\n",
        "torch.set_grad_enabled(False)\n",
145
        "device = 'cuda'"
146
      ],
147
      "metadata": {
148
        "id": "BpYbqENDl8kO"
149
      },
150
      "execution_count": null,
151
      "outputs": []
152
    },
153
    {
154
      "cell_type": "code",
155
      "source": [
156
        "# Load the pretrained terrain (triplane) and sky models; checkpoints were fetched by download.sh\n",
        "full_model = model_full.ModelFull('pretrained/model_triplane.pkl', 'pretrained/model_sky_360.pkl').to(device).eval()"
157
      ],
158
      "metadata": {
159
        "id": "HVcePJ84m_wE"
160
      },
161
      "execution_count": null,
162
      "outputs": []
163
    },
164
    {
165
      "cell_type": "code",
166
      "source": [
167
        "# Wrap the ground generator for SOAT stitching; G_pano is the sky panorama generator\n",
        "G = soat_util_triplane.init_soat_model(full_model.ground).eval().cuda()\n",
168
        "G_pano = full_model.sky.G\n",
169
        "grid = sky_util.make_grid(G_pano)\n",
170
        "input_layer = G_pano.synthesis.input\n",
171
        "\n",
172
        "# settings\n",
173
        "fov = 60\n",
174
        "box_warp = G.rendering_kwargs['box_warp']\n",
175
        "# NOTE(review): ray range and sampling density are doubled from the checkpoint defaults,\n",
        "# presumably trading speed for render quality -- confirm against the training config\n",
        "G.rendering_kwargs['ray_end'] *= 2\n",
176
        "G.rendering_kwargs['depth_resolution'] *= 2\n",
177
        "G.rendering_kwargs['depth_resolution_importance'] *= 2\n",
178
        "G.rendering_kwargs['y_clip'] = 8.0\n",
179
        "G.rendering_kwargs['decay_start'] = 0.9 * G.rendering_kwargs['ray_end']\n",
180
        "G.rendering_kwargs['sample_deterministic'] = True"
181
      ],
182
      "metadata": {
183
        "id": "E369mphFnBHu"
184
      },
185
      "execution_count": null,
186
      "outputs": []
187
    },
188
    {
189
      "cell_type": "markdown",
190
      "source": [
191
        "## Generate Initial Layout and Skydome Env Map"
192
      ],
193
      "metadata": {
194
        "id": "JIcWD-eknLH3"
195
      }
196
    },
197
    {
199
      "cell_type": "code",
200
      "source": [
201
        "# Fixed seed so the layout, latents, and camera sampling are reproducible\n",
        "seed = 10 # np.random.randint(0, 1000)\n",
202
        "grid_size = 5\n",
203
        "zs, c = soat_util_triplane.prepare_zs(seed, grid_h=grid_size, grid_w=grid_size)\n",
204
        "zs = soat_util_triplane.interpolate_zs(zs)\n",
205
        "\n",
206
        "# generate feature planes\n",
207
        "xz_soat = soat_util_triplane.generate_xz(zs, c) # [1, 32, 512, 512]\n",
208
        "xy_soat = soat_util_triplane.generate_xy(zs, c) # 2 x [1, 32, 256, 512]\n",
209
        "yz_soat = soat_util_triplane.generate_yz(zs, c) # 2 x [1, 32, 256, 512]\n",
210
        "planes = [xy_soat, xz_soat, yz_soat]\n",
211
        "\n",
212
        "# set up upsampler and sky inputs\n",
213
        "z = zs[0,0] # extract a z latent for the upsampler\n",
214
        "ws = soat_util_triplane.prepare_ws(z, torch.zeros_like(c))\n",
215
        "sky_z = z[:, : G_pano.z_dim]\n",
216
        "\n",
217
        "# rendered noise (may not be used depending on noise_mode for upsampler)\n",
218
        "noise_gen = noise_util.build_soat_noise(G, grid_size)\n",
219
        "noise_input = noise_gen.get_noise(batch_size=1, device=zs.device)"
220
      ],
221
      "metadata": {
222
        "id": "PdFAVIa5nL-S"
223
      },
224
      "execution_count": null,
225
      "outputs": []
226
    },
226
    {
      "cell_type": "code",
      "source": [
        "# How fast we adjust. Too large and it will overshoot.\n",
        "# Too small and it will not react in time to avoid mountains.\n",
        "tilt_velocity_scale = .3    # Keep this small, otherwise you'll get motion sickness.\n",
        "offset_velocity_scale = .5\n",
        "\n",
        "# How far up the image should the horizon be, ideally.\n",
        "# Suggested range: 0.5 to 0.7.\n",
        "horizon_target = 0.65\n",
        "\n",
        "# What proportion of the depth map should be \"near\" the camera, ideally.\n",
        "# The smaller the number, the higher up the camera will fly.\n",
        "# Suggested range: 0.05 to 0.2\n",
        "near_target = 0.2\n",
        "\n",
        "offset = 0.\n",
        "tilt = 0.\n",
        "initial_stabilize_frames = 10\n",
        "\n",
        "# sample a random camera\n",
        "sampled_camera, cam2world_matrix, intrinsics = soat_util_triplane.sample_random_camera(fov, box_warp, seed)\n",
        "intrinsics_matrix = intrinsics[None].to(device)\n",
        "\n",
        "# balance camera above the horizon: iterate render -> measure horizon/near -> nudge tilt/offset\n",
        "# (use the named constant instead of a magic 10 so the frame count is tunable in one place)\n",
        "for _ in range(initial_stabilize_frames):\n",
        "    adjusted_cam = camera_util.adjust_camera_vertically(sampled_camera, offset, tilt)\n",
        "    outputs, horizon, near = soat_util_triplane.generate_frame(\n",
        "        G, adjusted_cam, planes, ws, intrinsics_matrix, noise_input)\n",
        "    tilt += tilt_velocity_scale*(horizon - horizon_target)\n",
        "    offset += offset_velocity_scale*(near - near_target)\n",
        "print(adjusted_cam)\n",
        "\n",
        "# generate sky texture\n",
        "img_w_gray_sky = outputs['image_w_gray_sky']\n",
        "sky_encode = full_model.sky.encode(img_w_gray_sky)\n",
        "start_grid = sky_util.generate_start_grid(seed, input_layer, grid)\n",
        "sky_texture = sky_util.generate_pano_transform(G_pano, sky_z, sky_encode, start_grid)\n",
        "sky_texture = sky_texture.cuda()[None]\n",
        "display(renormalize.as_image(sky_texture[0]))"
      ],
      "metadata": {
        "id": "oqbpzSLVnNZT"
      },
      "execution_count": null,
      "outputs": []
    },
274
    {
275
      "cell_type": "markdown",
276
      "source": [
277
        "## Interactive Widget"
278
      ],
279
      "metadata": {
280
        "id": "G3eWHnwOnVW7"
281
      }
282
    },
283
    {
284
      "cell_type": "code",
285
      "source": [
286
        "# HTML widgets: l holds the rendered frames, h shows the current camera state\n",
        "l = HTML(\"\")\n",
287
        "h = HTML(\"\")\n",
288
        "display_size = (256, 256)\n",
289
        "\n",
290
        "\n",
291
        "layout_params = Layout(width='80px', height='40px')\n",
292
        "words = ['', 'forward', '', 'left', 'reset', 'right', '', 'backward', '']\n",
293
        "items = [Button(description=w, layout = layout_params) for w in words]\n",
294
        "top_box = HBox(items[:3])\n",
295
        "mid_box = HBox(items[3:6])\n",
296
        "bottom_box = HBox(items[6:])\n",
297
        "arrows = VBox([top_box, mid_box, bottom_box])\n",
298
        "\n",
299
        "\n",
300
        "camera = adjusted_cam\n",
301
        "camera_util.INITIAL_CAMERA = adjusted_cam\n",
302
        "h.value = str(camera)\n",
303
        "\n",
304
        "\n",
305
        "# Render the composite RGB next to a top-down ray visualization and refresh both widgets\n",
        "def update_display(outputs, camera):\n",
306
        "    composite_rgb_url = renormalize.as_url(outputs['composite'][0], size=display_size)\n",
307
        "\n",
308
        "\n",
309
        "    # calculate xyz points\n",
310
        "    ray_origins, ray_directions = G.ray_sampler(outputs['cam2world_matrix'], intrinsics_matrix, 32)\n",
311
        "    t_val = torch.linspace(G.rendering_kwargs['ray_start'], G.rendering_kwargs['ray_end'], 100, device=device).view(1, 1, -1, 1)\n",
312
        "    xyz = (ray_origins.unsqueeze(-2) + t_val * ray_directions.unsqueeze(-2))\n",
313
        "    vis_rays =  camera_util.visualize_rays(G, outputs['world2cam_matrix'], xyz,\n",
314
        "                                       xz_soat, display_size[0])\n",
315
        "    cam_img = renormalize.as_image(vis_rays)\n",
316
        "    cam_url = renormalize.as_url(cam_img, size=display_size)\n",
317
        "    img_html = ('<div class=\"row\"> <img src=\"%s\"/> <img src=\"%s\"/> </div>' % (composite_rgb_url, cam_url))\n",
318
        "    l.value = img_html\n",
319
        "    h.value = str(camera)\n",
320
        "\n",
321
        "# Step the camera for a keypress-style event, re-render, and nudge tilt/offset toward\n",
        "# the horizon/near targets (same feedback loop as the stabilization cell above)\n",
        "def handle_event(event):\n",
322
        "    global camera, offset, tilt\n",
323
        "    camera = camera_util.update_camera(camera, event['key'], auto_adjust_height_and_tilt=True)\n",
324
        "    c = camera_util.adjust_camera_vertically(camera, offset, tilt)\n",
325
        "    outputs, horizon, near = soat_util_triplane.generate_frame(\n",
326
        "        G, c, planes, ws, intrinsics_matrix, noise_input, sky_texture=sky_texture)\n",
327
        "    tilt += tilt_velocity_scale*(horizon - horizon_target)\n",
328
        "    offset += offset_velocity_scale*(near - near_target)\n",
329
        "    update_display(outputs, c)\n",
330
        "\n",
331
        "# Translate button labels into WASD-style keys; 'x' presumably restores\n",
        "# camera_util.INITIAL_CAMERA (set above) -- confirm in camera_util.update_camera\n",
        "def on_button_clicked(b):\n",
332
        "    clicked = b.description\n",
333
        "    options = {'forward': 'w', 'backward': 's', 'left': 'a',\n",
334
        "               'right': 'd', 'reset': 'x'}\n",
335
        "    val = options.get(clicked)\n",
336
        "    if val:\n",
337
        "        handle_event({'key': val})\n",
338
        "\n",
339
        "\n",
340
        "for button in items:\n",
341
        "    button.on_click(on_button_clicked)\n",
342
        "\n",
343
        "display(h, HBox([l, arrows]))\n",
344
        "# Trigger the initial render via the reset key\n",
        "handle_event({'key': 'x'})"
345
      ],
346
      "metadata": {
347
        "id": "X2xvcjZmnP4B"
348
      },
349
      "execution_count": null,
350
      "outputs": []
351
    }
352
  ]
353
}
354

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.