# google-research
# (552 lines · 21.5 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: skip-file
17"""Training and evalution for score-based generative models."""
18
19import functools20import gc21import io22import os23import time24from typing import Any25
26from . import datasets27from . import evaluation28from . import losses29from . import models # Keep this import for registering all model definitions.30from . import sampling31from . import utils32from .models import utils as mutils33from absl import logging34import flax35import flax.deprecated.nn as nn36import flax.jax_utils as flax_utils37from flax.metrics import tensorboard38from flax.training import checkpoints39import jax40import jax.numpy as jnp41import ml_collections42from .models import ddpm, ncsnv2, ncsnv343import numpy as np44import tensorflow as tf45import tensorflow_gan as tfgan46
47
def train(config, workdir):
  """Runs a training loop.

  Trains a score-based generative model (NCSN- or DDPM-style loss) with
  `jax.pmap` data parallelism, periodically evaluating the loss, writing
  TensorBoard summaries, checkpointing, and saving generated samples.

  Args:
    config: Configuration to use (an `ml_collections.ConfigDict`-style object
      — TODO confirm; only attribute access is used here).
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoints, training will be resumed from the latest one.
  """

  # Create directories for experimental logs.
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # Only the first host writes TensorBoard summaries in multi-host setups.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels.
  class_conditional = "conditional" in config.training.loss.lower()
  # Legacy flax.deprecated.nn API: `stateful` captures mutable model state
  # (e.g. batch stats), `stochastic` seeds internal RNG during init.
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      # One example per local device; images assumed HxWx3.
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      # Inputs: image batch (float32) and per-example noise-level/label
      # indices (int32).
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # Class-conditional models take an extra int32 label input of the
        # same shape as the noise-level indices.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  # Full training state: optimizer (wraps model params), mutable model
  # state, EMA of parameters, and the RNG to restore on preemption.
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically.
  # Note: only the meta checkpoint's restored state is kept (preemption
  # resume takes precedence over the periodic snapshots above).
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels.
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  # Decorrelate RNG streams across hosts.
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One fresh RNG per local device for the pmapped step.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      # Persist the host RNG so the stream resumes where it left off.
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples.
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        # Each host writes its own sample directory.
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          # Flatten the leading (device, batch) axes into one grid axis.
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          # uint8 copy for the raw dump; the PNG grid below uses the
          # pre-clip float values (utils.save_image presumably expects
          # floats in [0, 1] — TODO confirm).
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
245
def evaluate(config,
             workdir,
             eval_folder = "eval"):
  """Evaluate trained models.

  For each checkpoint in `[config.eval.begin_ckpt, config.eval.end_ckpt]`,
  computes the evaluation loss over the full eval dataset, draws samples,
  runs them through an Inception network, and reports IS/FID/KID. The loop
  is pre-emption safe: per-host `EvalMeta` checkpoints record the last
  finished (checkpoint, round) pair and sampling resumes from there.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Defaults to
      "eval".
  """
  # Create eval_dir.
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  # Offset the seed so evaluation RNG differs from training RNG.
  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      # Per-device input shape taken from the eval dataset element spec.
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  # Decorrelate RNG streams across hosts.
  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing evaluation progress across preemptions.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  # Resume within the same checkpoint if rounds remain; otherwise move on
  # to the next checkpoint.
  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution.
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        # `logging.warn` is a deprecated alias; use `warning`.
        logging.warning("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet,
    # retry with increasing back-off. Catch `Exception` (not bare
    # `except:`) so KeyboardInterrupt/SystemExit still propagate.
    try:
      state = utils.load_state_dict(ckpt_filename, state)
    except Exception:  # pylint: disable=broad-except
      time.sleep(60)
      try:
        state = utils.load_state_dict(ckpt_filename, state)
      except Exception:  # pylint: disable=broad-except
        time.sleep(120)
        # Final attempt: let the error propagate if it still fails.
        state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng, config, state, init_shape,
                                     scaler, inverse_scaler,
                                     class_conditional=class_conditional)
      # Keep only the final sampling step and convert to uint8 images.
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      # Write via an in-memory buffer so np.savez works with GFile.
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      # Free memory before/after the large Inception forward pass.
      gc.collect()
      latents = evaluation.run_inception_distributed(samples, inception_model,
                                                     inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs.
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        # Poll until every round's statistics file from this host exists.
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warning("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      # Truncate to exactly num_samples (the extra round may overshoot).
      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on
        # each to obtain a spread of the metrics.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              ISs=inception_scores, fids=fids, kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          # IS is undefined for the InceptionV3-pool-only path.
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID work for eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
            ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              IS=inception_score, fid=fid, kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)

      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    # Subsequent checkpoints always start from round 0.
    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)