google-research

Форк
0
361 строка · 13.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Function to train the rendering model."""
17

18
import functools
19
import json
20
import os
21
from typing import Any, Callable, Tuple
22

23
from absl import logging
24

25
from clu import metric_writers
26
from clu import metrics
27
from clu import periodic_actions
28
import flax
29
import flax.jax_utils as flax_utils
30
import flax.linen as nn
31
from flax.training import checkpoints
32
from flax.training import train_state
33
import jax
34
import jax.numpy as jnp
35
import ml_collections
36
import numpy as np
37
import tensorflow as tf
38

39
from gen_patch_neural_rendering.src import datasets
40
from gen_patch_neural_rendering.src import models
41
from gen_patch_neural_rendering.src.utils import data_types
42
from gen_patch_neural_rendering.src.utils import file_utils
43
from gen_patch_neural_rendering.src.utils import model_utils
44
from gen_patch_neural_rendering.src.utils import render_utils
45
from gen_patch_neural_rendering.src.utils import train_utils
46

47

48
def train_step(
    model, rng, state,
    batch, learning_rate_fn,
    weight_decay,
    config):
  """Perform a single train step.

  Intended to run inside `jax.pmap` with `axis_name="batch"`, because the
  gradients are averaged with `jax.lax.pmean` over that axis.

  Args:
    model: Flax module for the model. Its `apply` method is called as
      `apply(variables, key_0, key_1, batch, randomized=...)` and must return
      a sequence of 1 (coarse only) or 2 (coarse, fine) outputs, each of which
      unpacks into `(rgb, disp, acc)`.
    rng: random number generator key. It is split in three: two keys are handed
      to the model and the third is returned for the next step.
    state: State of the model (optimizer and state). Must expose `step`,
      `params` and `apply_gradients`.
    batch: Training inputs for this step. `batch.target_view.rgb` is used as
      the regression target.
    learning_rate_fn: Function that computes the learning rate given the step
      number. The value is only reported in the metrics here.
    weight_decay: Weighs L2 regularization term.
    config: experiment config dict. Only `config.model.randomized` is read.

  Returns:
    A tuple `(new_state, metrics_update, rng)`: the state after applying the
    averaged gradients, the `train_utils.TrainMetrics` gathered for this step,
    and the carried-over random key.

  Raises:
    ValueError: If the model returns neither 1 nor 2 sets of outputs.
  """
  logging.info("train_step(batch=%s)", batch)

  step = state.step + 1
  lr = learning_rate_fn(step)
  rng, key_0, key_1 = jax.random.split(rng, 3)

  def loss_fn(params):
    """Returns `(total_loss, stats)` for the given parameters."""
    variables = {"params": params}
    ret = model.apply(
        variables, key_0, key_1, batch, randomized=config.model.randomized)
    if len(ret) not in (1, 2):
      raise ValueError(
          "ret should contain either 1 set of output (coarse only), or 2 sets "
          "of output (coarse as ret[0] and fine as ret[1]).")
    #------------------------------------------------------------------------
    # Main prediction
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    batch_pixels = model_utils.uint2float(batch.target_view.rgb)
    # Only the first three channels of the target are compared.
    target_rgb = batch_pixels[..., :3]
    loss = ((rgb - target_rgb)**2).mean()
    psnr = model_utils.compute_psnr(loss)

    #------------------------------------------------------------------------
    # Coarse / Regularization Prediction
    if len(ret) > 1:
      # If there are both coarse and fine predictions, we compute the loss for
      # the coarse prediction (ret[0]) as well.
      rgb_c, unused_disp_c, unused_acc_c = ret[0]
      loss_c = ((rgb_c - target_rgb)**2).mean()
      psnr_c = model_utils.compute_psnr(loss_c)
    else:
      loss_c = 0.
      psnr_c = 0.

    #------------------------------------------------------------------------
    # Weight Regularization
    # `jax.tree_util.tree_leaves` replaces the deprecated top-level
    # `jax.tree_leaves` alias. Leaves with ndim <= 1 are excluded from the
    # penalty.
    weight_penalty_params = jax.tree_util.tree_leaves(params)
    weight_l2 = sum(
        jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2

    #------------------------------------------------------------------------
    # Compute total loss and wrap the stats
    total_loss = loss + loss_c + weight_penalty
    stats = train_utils.Stats(
        loss=loss, psnr=psnr, loss_c=loss_c, psnr_c=psnr_c, weight_l2=weight_l2)
    return total_loss, stats

  #------------------------------------------------------------------------
  # Compute Gradients
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, stats), grad = grad_fn(state.params)

  # Compute average gradient across multiple workers.
  grad = jax.lax.pmean(grad, axis_name="batch")

  #------------------------------------------------------------------------
  # Update States
  new_state = state.apply_gradients(grads=grad)

  metrics_update = train_utils.TrainMetrics.gather_from_model_output(
      total_loss=loss,
      loss=stats.loss,
      psnr=stats.psnr,
      loss_c=stats.loss_c,
      psnr_c=stats.psnr_c,
      weight_l2=stats.weight_l2,
      learning_rate=lr)
  return new_state, metrics_update, rng
139

140

141
def eval_step(state, rng, batch,
              render_pfn, config):
  """Render the view described by `batch` with the current parameters.

  The first replica of `state` is copied to the host and its `params` are
  bound to `render_pfn`; the actual rendering is delegated to
  `render_utils.render_image`.

  Args:
    state: Replicated model state (leading axis indexes the replicas).
    rng: random number generator.
    batch: data_types.Batch. Inputs that should be evaluated.
    render_pfn: pmapped render function; called with the variables dict as its
      first argument.
    config: experiment config. `config.dataset.name` and `config.eval.chunk`
      are read.

  Returns:
    The tuple `(pred_color, pred_disp, pred_acc)` produced by
    `render_utils.render_image`.
  """
  logging.info("eval_step=================")
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`
  # alias. Index 0 selects the first replica of every leaf.
  host_state = jax.device_get(
      jax.tree_util.tree_map(lambda x: x[0], state))
  variables = {
      "params": host_state.params,
  }
  pred_color, pred_disp, pred_acc = render_utils.render_image(
      functools.partial(render_pfn, variables),
      batch,
      rng,
      render_utils.normalize_disp(config.dataset.name),
      chunk=config.eval.chunk)

  return pred_color, pred_disp, pred_acc
169

170

171
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Builds the datasets and the model, restores the latest checkpoint from
  `workdir` (or pretrained weights from `config.train.pretrain_dir` when no
  checkpoint exists), and then runs the pmapped train step until
  `num_train_steps`. Along the way it periodically switches the training
  scene, logs scalars, renders a test view and saves checkpoints.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    ValueError: If `config.dataset.batch_size` is not divisible by
      `jax.device_count()`.
  """
  if config.dataset.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  tf.io.gfile.makedirs(workdir)
  # Deterministic training.
  rng = jax.random.PRNGKey(config.seed)
  # Shift the numpy random seed by process_index() to shuffle data loaded
  # by different hosts
  np.random.seed(20201473 + jax.process_index())

  #----------------------------------------------------------------------------
  # Build input pipeline.
  # NOTE: `data_rng` is not consumed anywhere below. The split is kept so that
  # the `rng` stream used for model init and training stays unchanged.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())

  scene_path_list = train_utils.get_train_scene_list(config)

  # Training starts on the first scene; later scenes are sampled in the loop.
  train_ds = datasets.create_train_dataset(config, scene_path_list[0])
  _, eval_ds_dict = datasets.create_eval_dataset(config)
  # Only one (arbitrary) evaluation dataset is used for in-training renders.
  _, eval_ds = eval_ds_dict.popitem()
  example_batch = train_ds.peek()

  #----------------------------------------------------------------------------
  # Learning rate schedule.
  num_train_steps = config.train.max_steps
  if num_train_steps == -1:
    num_train_steps = train_ds.size()
  steps_per_epoch = num_train_steps // config.train.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  learning_rate_fn = train_utils.create_learning_rate_fn(config)

  #----------------------------------------------------------------------------
  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = models.create_train_state(
      config,
      model_rng,
      learning_rate_fn=learning_rate_fn,
      example_batch=example_batch,
  )

  #----------------------------------------------------------------------------
  # Set up checkpointing of the model and the input pipeline.

  # Check if the job was stopped and relaunched.
  latest_ckpt = checkpoints.latest_checkpoint(workdir)
  if latest_ckpt is None:
    # No previous checkpoint. Then check for pretrained weights.
    if config.train.pretrain_dir:
      state = checkpoints.restore_checkpoint(config.train.pretrain_dir, state)
  else:
    state = checkpoints.restore_checkpoint(workdir, state)

  initial_step = int(state.step) + 1
  step_per_scene = config.train.switch_scene_iter
  if config.dev_run:
    # Development runs print fewer digits and switch scenes every 3 steps.
    jnp.set_printoptions(precision=2)
    np.set_printoptions(precision=2)
    step_per_scene = 3

  #----------------------------------------------------------------------------
  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.train.weight_decay,
          config=config,
      ),
      axis_name="batch",
  )

  # Get distributed rendering function
  render_pfn = render_utils.get_render_function(
      model=model,
      config=config,
      randomized=False,  # No randomization for evaluation.
  )

  #----------------------------------------------------------------------------
  # Prepare Metric Writers
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks.append(report_progress)
  train_metrics = None

  # Prefetch_buffer_size = 6 x batch_size
  ptrain_ds = flax_utils.prefetch_to_device(train_ds, 6)
  n_local_devices = jax.local_device_count()
  # Make random seed separate across hosts.
  # NOTE(review): this adds an integer to the raw key data;
  # `jax.random.fold_in` is the documented way to derive per-host keys, but
  # switching would change the random stream of existing runs — confirm before
  # changing.
  rng = rng + jax.process_index()
  keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      if step % step_per_scene == 0:
        # Switch to a randomly chosen training scene.
        scene_idx = np.random.randint(len(scene_path_list))
        curr_scene = scene_path_list[scene_idx]
        logging.info("Loading scene %s", curr_scene)
        if config.dataset.name == "dtu":
          # lighting can take values between 0 and 6 (both included)
          config.dataset.dtu_light_idx = np.random.randint(low=0, high=7)
        train_ds = datasets.create_train_dataset(config, curr_scene)
        ptrain_ds = flax_utils.prefetch_to_device(train_ds, 6)

      is_last_step = step == num_train_steps
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = next(ptrain_ds)
        state, metrics_update, keys = p_train_step(
            rng=keys, state=state, batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        # Accumulate metrics until the next logging step resets them.
        train_metrics = (
            metric_update
            if train_metrics is None else train_metrics.merge(metric_update))
      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.train.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      if step % config.train.render_every_steps == 0 or is_last_step:
        test_batch = next(eval_ds)
        test_pixels = model_utils.uint2float(
            test_batch.target_view.rgb)  # extract for evaluation
        with report_progress.timed("eval"):
          pred_color, pred_disp, pred_acc = eval_step(state, keys[0],
                                                      test_batch, render_pfn,
                                                      config)
        #------------------------------------------------------------------
        # Log metrics and images for host 0
        #------------------------------------------------------------------
        if jax.process_index() == 0:
          psnr = model_utils.compute_psnr(
              ((pred_color - test_pixels)**2).mean())
          # SSIM is not computed here; a constant 0 is logged in its place.
          ssim = 0.
          writer.write_scalars(step, {
              "train_eval/test_psnr": psnr,
              "train_eval/test_ssim": ssim,
          })
          writer.write_images(
              step, {
                  "test_pred_color": pred_color[None, :],
                  "test_target": test_pixels[None, :]
              })
          if pred_disp is not None:
            writer.write_images(step, {"test_pred_disp": pred_disp[None, :]})
          if pred_acc is not None:
            writer.write_images(step, {"test_pred_acc": pred_acc[None, :]})
        #------------------------------------------------------------------

      if (jax.process_index()
          == 0) and (step % config.train.checkpoint_every_steps == 0 or
                     is_last_step):
        # Write the metrics of the most recent step to file (overwritten on
        # every checkpoint).
        with file_utils.open_file(
            os.path.join(workdir, "train_logs.json"), "w") as f:
          log_dict = {
              k: v.item() for k, v in metric_update.compute().items()
          }
          f.write(json.dumps(log_dict))
        with report_progress.timed("checkpoint"):
          # `jax.tree_util.tree_map` replaces the deprecated top-level
          # `jax.tree_map` alias. Only the first replica is saved.
          state_to_save = jax.device_get(
              jax.tree_util.tree_map(lambda x: x[0], state))
          checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100)

  logging.info("Finishing training at step %d", num_train_steps)
362

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.