google-research
361 lines · 13.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Function to train the rendering model."""
17
18import functools19import json20import os21from typing import Any, Callable, Tuple22
23from absl import logging24
25from clu import metric_writers26from clu import metrics27from clu import periodic_actions28import flax29import flax.jax_utils as flax_utils30import flax.linen as nn31from flax.training import checkpoints32from flax.training import train_state33import jax34import jax.numpy as jnp35import ml_collections36import numpy as np37import tensorflow as tf38
39from gen_patch_neural_rendering.src import datasets40from gen_patch_neural_rendering.src import models41from gen_patch_neural_rendering.src.utils import data_types42from gen_patch_neural_rendering.src.utils import file_utils43from gen_patch_neural_rendering.src.utils import model_utils44from gen_patch_neural_rendering.src.utils import render_utils45from gen_patch_neural_rendering.src.utils import train_utils46
47
def train_step(model, rng, state, batch, learning_rate_fn, weight_decay,
               config):
  """Perform a single train step.

  Args:
    model: Flax module for the model. The apply method must take input images
      and a boolean argument indicating whether to use training or inference
      mode.
    rng: random number generator.
    state: State of the model (optimizer and state).
    batch: Training inputs for this step.
    learning_rate_fn: Function that computes the learning rate given the step
      number.
    weight_decay: Weighs L2 regularization term.
    config: experiment config dict.

  Returns:
    A (new_state, metrics_update, rng) tuple: the updated model state, the
    gathered training metrics for this step, and the advanced RNG.
  """
  logging.info("train_step(batch=%s)", batch)

  step = state.step + 1
  lr = learning_rate_fn(step)
  # Advance the RNG and derive two independent keys for the model's
  # stochastic components; the advanced `rng` is returned to the caller.
  rng, key_0, key_1 = jax.random.split(rng, 3)

  def loss_fn(params):
    variables = {"params": params}
    ret = model.apply(
        variables, key_0, key_1, batch, randomized=config.model.randomized)
    if len(ret) not in (1, 2):
      raise ValueError(
          "ret should contain either 1 set of output (coarse only), or 2 sets"
          "of output (coarse as ret[0] and fine as ret[1]).")
    #------------------------------------------------------------------------
    # Main prediction
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    batch_pixels = model_utils.uint2float(batch.target_view.rgb)
    loss = ((rgb - batch_pixels[Ellipsis, :3])**2).mean()
    psnr = model_utils.compute_psnr(loss)

    #------------------------------------------------------------------------
    # Coarse / Regularization Prediction
    if len(ret) > 1:
      # If there are both coarse and fine predictions, we compute the loss for
      # the coarse prediction (ret[0]) as well.
      rgb_c, unused_disp_c, unused_acc_c = ret[0]
      loss_c = ((rgb_c - batch_pixels[Ellipsis, :3])**2).mean()
      psnr_c = model_utils.compute_psnr(loss_c)
    else:
      loss_c = 0.
      psnr_c = 0.

    #------------------------------------------------------------------------
    # Weight Regularization
    # Only penalize parameters with ndim > 1 (weight matrices); biases and
    # other 1-D parameters are excluded from the L2 term.
    # NOTE: `jax.tree_leaves` is a deprecated alias removed in recent JAX;
    # `jax.tree_util.tree_leaves` is the stable spelling.
    weight_penalty_params = jax.tree_util.tree_leaves(variables["params"])
    weight_l2 = sum(
        [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
    weight_penalty = weight_decay * 0.5 * weight_l2

    #------------------------------------------------------------------------
    # Compute total loss and wrap the stats
    total_loss = loss + loss_c + weight_penalty
    stats = train_utils.Stats(
        loss=loss, psnr=psnr, loss_c=loss_c, psnr_c=psnr_c, weight_l2=weight_l2)
    return total_loss, stats

  #------------------------------------------------------------------------
  # Compute Gradients
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, stats), grad = grad_fn(state.params)  # `loss` here is the total loss.

  # Compute average gradient across multiple workers.
  grad = jax.lax.pmean(grad, axis_name="batch")

  #------------------------------------------------------------------------
  # Update States
  new_state = state.apply_gradients(grads=grad)

  metrics_update = train_utils.TrainMetrics.gather_from_model_output(
      total_loss=loss,
      loss=stats.loss,
      psnr=stats.psnr,
      loss_c=stats.loss_c,
      psnr_c=stats.psnr_c,
      weight_l2=stats.weight_l2,
      learning_rate=lr)
  return new_state, metrics_update, rng
140
def eval_step(state, rng, batch, render_pfn, config):
  """Render one full image with the given model in inference mode.

  The model is applied to the inputs with train=False using all devices on the
  host.

  Args:
    state: Replicated model state.
    rng: random number generator.
    batch: data_types.Batch. Inputs that should be evaluated.
    render_pfn: pmapped render function.
    config: experiment config.

  Returns:
    A (pred_color, pred_disp, pred_acc) tuple of rendered color, disparity
    and accumulated-opacity outputs produced by render_utils.render_image.
  """
  logging.info("eval_step=================")
  # Fetch a single (unreplicated) copy of the parameters from device 0.
  # NOTE: `jax.tree_map` is a deprecated alias removed in recent JAX;
  # `jax.tree_util.tree_map` is the stable spelling.
  variables = {
      "params": jax.device_get(
          jax.tree_util.tree_map(lambda x: x[0], state)).params,
  }
  # Render in chunks of config.eval.chunk rays to bound device memory use.
  pred_color, pred_disp, pred_acc = render_utils.render_image(
      functools.partial(render_pfn, variables),
      batch,
      rng,
      render_utils.normalize_disp(config.dataset.name),
      chunk=config.eval.chunk)

  return pred_color, pred_disp, pred_acc
170
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    ValueError: if the configured batch size is not divisible by the number
      of available JAX devices.
  """
  if config.dataset.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  tf.io.gfile.makedirs(workdir)
  # Deterministic training.
  rng = jax.random.PRNGKey(config.seed)
  # Shift the numpy random seed by process_index() to shuffle data loaded
  # by different hosts.
  np.random.seed(20201473 + jax.process_index())

  #----------------------------------------------------------------------------
  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  # NOTE(review): data_rng is derived but not used below — presumably kept to
  # keep the `rng` stream stable across revisions; confirm before removing.
  data_rng = jax.random.fold_in(data_rng, jax.process_index())

  scene_path_list = train_utils.get_train_scene_list(config)

  train_ds = datasets.create_train_dataset(config, scene_path_list[0])
  _, eval_ds_dict = datasets.create_eval_dataset(config)
  _, eval_ds = eval_ds_dict.popitem()
  example_batch = train_ds.peek()

  #----------------------------------------------------------------------------
  # Learning rate schedule.
  num_train_steps = config.train.max_steps
  if num_train_steps == -1:
    num_train_steps = train_ds.size()
  steps_per_epoch = num_train_steps // config.train.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  learning_rate_fn = train_utils.create_learning_rate_fn(config)

  #----------------------------------------------------------------------------
  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = models.create_train_state(
      config,
      model_rng,
      learning_rate_fn=learning_rate_fn,
      example_batch=example_batch,
  )

  #----------------------------------------------------------------------------
  # Set up checkpointing of the model and the input pipeline.

  # Check if the job was stopped and relaunched.
  latest_ckpt = checkpoints.latest_checkpoint(workdir)
  if latest_ckpt is None:
    # No previous checkpoint. Then check for pretrained weights.
    if config.train.pretrain_dir:
      state = checkpoints.restore_checkpoint(config.train.pretrain_dir, state)
  else:
    state = checkpoints.restore_checkpoint(workdir, state)

  initial_step = int(state.step) + 1
  step_per_scene = config.train.switch_scene_iter
  if config.dev_run:
    jnp.set_printoptions(precision=2)
    np.set_printoptions(precision=2)
    step_per_scene = 3

  #----------------------------------------------------------------------------
  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.train.weight_decay,
          config=config,
      ),
      axis_name="batch",
  )

  # Get distributed rendering function.
  render_pfn = render_utils.get_render_function(
      model=model,
      config=config,
      randomized=False,  # No randomization for evaluation.
  )

  #----------------------------------------------------------------------------
  # Prepare Metric Writers
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
    ]
  train_metrics = None

  # Prefetch_buffer_size = 6 x batch_size
  ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
  n_local_devices = jax.local_device_count()
  # NOTE(review): this relies on legacy uint32 PRNG keys supporting integer
  # addition; with new-style typed keys this line would fail — confirm the
  # configured jax key implementation.
  rng = rng + jax.process_index()  # Make random seed separate across hosts.
  keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      if step % step_per_scene == 0:
        # Periodically switch to a randomly chosen training scene.
        scene_idx = np.random.randint(len(scene_path_list))
        logging.info("Loading scene {}".format(scene_path_list[scene_idx]))  # pylint: disable=logging-format-interpolation
        curr_scene = scene_path_list[scene_idx]
        if config.dataset.name == "dtu":
          # lighting can take values between 0 and 6 (both included)
          config.dataset.dtu_light_idx = np.random.randint(low=0, high=7)
        train_ds = datasets.create_train_dataset(config, curr_scene)
        ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)

      is_last_step = step == num_train_steps
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = next(ptrain_ds)
        state, metrics_update, keys = p_train_step(
            rng=keys, state=state, batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (
            metric_update
            if train_metrics is None else train_metrics.merge(metric_update))
      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.train.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      if step % config.train.render_every_steps == 0 or is_last_step:
        test_batch = next(eval_ds)
        test_pixels = model_utils.uint2float(
            test_batch.target_view.rgb)  # extract for evaluation
        with report_progress.timed("eval"):
          pred_color, pred_disp, pred_acc = eval_step(state, keys[0],
                                                      test_batch, render_pfn,
                                                      config)
        #------------------------------------------------------------------
        # Log metrics and images for host 0
        #------------------------------------------------------------------
        if jax.process_index() == 0:
          psnr = model_utils.compute_psnr(
              ((pred_color - test_pixels)**2).mean())
          ssim = 0.  # SSIM is not computed; logged as a constant placeholder.
          writer.write_scalars(step, {
              "train_eval/test_psnr": psnr,
              "train_eval/test_ssim": ssim,
          })
          writer.write_images(
              step, {
                  "test_pred_color": pred_color[None, :],
                  "test_target": test_pixels[None, :]
              })
          if pred_disp is not None:
            writer.write_images(step, {"test_pred_disp": pred_disp[None, :]})
          if pred_acc is not None:
            writer.write_images(step, {"test_pred_acc": pred_acc[None, :]})
        #------------------------------------------------------------------

      if (jax.process_index()
          == 0) and (step % config.train.checkpoint_every_steps == 0 or
                     is_last_step):
        # Write final metrics to file
        with file_utils.open_file(
            os.path.join(workdir, "train_logs.json"), "w") as f:
          log_dict = metric_update.compute()
          for k, v in log_dict.items():
            log_dict[k] = v.item()
          f.write(json.dumps(log_dict))
        with report_progress.timed("checkpoint"):
          # Un-replicate before saving: take the device-0 copy of the state.
          # NOTE: `jax.tree_map` is a deprecated alias removed in recent JAX;
          # `jax.tree_util.tree_map` is the stable spelling.
          state_to_save = jax.device_get(
              jax.tree_util.tree_map(lambda x: x[0], state))
          checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100)

  logging.info("Finishing training at step %d", num_train_steps)