# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization."""

import flax
from flax.deprecated import nn
import jax
import jax.numpy as jnp
import jax.random as random
def get_optimizer(config):
  """Returns a flax optimizer object based on `config`."""
  optimizer = None
  if config.optim.optimizer == 'Adam':
    optimizer = flax.optim.Adam(beta1=config.optim.beta1, eps=config.optim.eps,
                                weight_decay=config.optim.weight_decay)
  else:
    raise NotImplementedError(
        f'Optimizer {config.optim.optimizer} not supported yet!')

  return optimizer
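# Example sketch (illustrative only): the `config.optim` fields read by
# `get_optimizer` and `optimization_manager`. The use of
# `ml_collections.ConfigDict` and the numeric values below are assumptions
# for illustration; the real training configs are defined elsewhere.
def _example_optim_config():
  import ml_collections
  config = ml_collections.ConfigDict()
  config.optim = ml_collections.ConfigDict()
  config.optim.optimizer = 'Adam'
  config.optim.beta1 = 0.9
  config.optim.eps = 1e-8
  config.optim.weight_decay = 0.
  config.optim.warmup = 5000   # linear learning-rate warmup steps
  config.optim.grad_clip = 1.  # global-norm clip threshold; < 0 disables
  return config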
def optimization_manager(config):
  """Returns an optimize_fn based on config."""
  def optimize(state,
               grad,
               warmup=config.optim.warmup,
               grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    lr = state.lr
    if warmup > 0:
      lr = lr * jnp.minimum(state.step / warmup, 1.0)
    if grad_clip >= 0:
      # Compute global gradient norm
      grad_norm = jnp.sqrt(
          sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
      # Clip gradient
      clipped_grad = jax.tree_map(
          lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
    else:  # disabling gradient clipping if grad_clip < 0
      clipped_grad = grad
    return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)

  return optimize
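# Example sketch (illustrative only): how the two helpers above are meant to
# fit together. `initial_params` is a hypothetical parameter pytree; the
# actual training loop that wires these pieces up lives elsewhere.
def _example_build_optimizer(config, initial_params):
  optimizer_def = get_optimizer(config)             # flax.optim.Adam definition
  optimizer = optimizer_def.create(initial_params)  # wraps params + opt state
  optimize_fn = optimization_manager(config)        # warmup + clip + update
  return optimizer, optimize_fn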
def ncsn_loss(rng,
              state,
              batch,
              sigmas,
              continuous=False,
              train=True,
              optimize_fn=None,
              anneal_power=2.,
              loss_per_sigma=False,
              class_conditional=False,
              pmap_axis_name='batch'):
  """The objective function for NCSN.

  Does one step of training or evaluation.
  Stores EMA statistics during training and uses EMA parameters for
  evaluation. Will be called by jax.pmap using `pmap_axis_name`.

  Args:
    rng: a jax random state.
    state: a pytree of training states, including the optimizer, lr, etc.
    batch: a pytree of data points.
    sigmas: a numpy array of noise levels.
    continuous: Use a continuous distribution of sigmas and sample from it.
    train: True if we will train the model. Otherwise just do the evaluation.
    optimize_fn: takes state and grad and performs one optimization step.
    anneal_power: balancing losses of different noise levels. Defaults to 2.
    loss_per_sigma: return the loss for each sigma separately.
    class_conditional: train a score-based model conditioned on class labels.
    pmap_axis_name: the axis_name used when calling this function with pmap.

  Returns:
    loss, new_state if not loss_per_sigma. Otherwise return loss, new_state,
    losses, and used_sigmas. Here used_sigmas are noise levels sampled in this
    mini-batch, and `losses` contains the loss value for each datapoint and
    noise level.
  """
  x = batch['image']
  rng1, rng2 = random.split(rng)
  if not continuous:
    labels = random.choice(rng1, len(sigmas), shape=(x.shape[0],))
    used_sigmas = sigmas[labels].reshape(
        (x.shape[0], *([1] * len(x.shape[1:]))))
  else:
    labels = random.uniform(
        rng1, (x.shape[0],),
        minval=jnp.log(sigmas[-1]),
        maxval=jnp.log(sigmas[0]))
    labels = jnp.exp(labels)
    used_sigmas = labels.reshape((x.shape[0], *([1] * len(x.shape[1:]))))

  if class_conditional:
    class_labels = batch['label']

  noise = random.normal(rng2, x.shape) * used_sigmas
  perturbed_data = noise + x

  run_rng, _ = random.split(rng2)

  @jax.jit
  def loss_fn(model):
    if train:
      with nn.stateful(state.model_state) as new_model_state:
        with nn.stochastic(run_rng):
          if not class_conditional:
            scores = model(perturbed_data, labels, train=train)
          else:
            scores = model(perturbed_data, labels, y=class_labels, train=train)
    else:
      with nn.stateful(state.model_state, mutable=False):
        with nn.stochastic(run_rng):
          if not class_conditional:
            scores = model(perturbed_data, labels, train=train)
          else:
            scores = model(perturbed_data, labels, y=class_labels, train=train)

      new_model_state = state.model_state

    scores = scores.reshape((scores.shape[0], -1))
    target = -1 / (used_sigmas ** 2) * noise
    target = target.reshape((target.shape[0], -1))
    losses = 1 / 2. * ((scores - target)**
                       2).sum(axis=-1) * used_sigmas.squeeze()**anneal_power
    loss = jnp.mean(losses)

    if loss_per_sigma:
      return loss, new_model_state, losses
    else:
      return loss, new_model_state

  if train:
    grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
    if loss_per_sigma:
      (loss, new_model_state, losses), grad = grad_fn(state.optimizer.target)
    else:
      (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name=pmap_axis_name)
    new_optimizer = optimize_fn(state, grad)
    new_params_ema = jax.tree_map(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target.params)
    step = state.step + 1
    new_state = state.replace(  # pytype: disable=attribute-error
        step=step,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema)
  else:
    model_ema = state.optimizer.target.replace(params=state.params_ema)
    if loss_per_sigma:
      loss, _, losses = loss_fn(model_ema)  # pytype: disable=bad-unpacking
    else:
      loss, *_ = loss_fn(model_ema)

    new_state = state

  loss = jax.lax.pmean(loss, axis_name=pmap_axis_name)
  if loss_per_sigma:
    return loss, new_state, losses, used_sigmas.squeeze()
  else:
    return loss, new_state
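# Example sketch (illustrative only): `ncsn_loss` is designed to run under
# jax.pmap with an axis name matching `pmap_axis_name`, since it reduces
# gradients and losses with jax.lax.pmean. The wrapper below shows that
# contract; `sigmas` and `optimize_fn` come from the caller and the real
# training loop lives elsewhere.
def _example_pmap_ncsn_train_step(sigmas, optimize_fn):
  import functools
  train_step = functools.partial(
      ncsn_loss,
      sigmas=sigmas,
      continuous=False,
      train=True,
      optimize_fn=optimize_fn)
  # Maps (per-device rng, state, batch) -> (loss, new_state) across devices.
  return jax.pmap(train_step, axis_name='batch')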
def ddpm_loss(rng,
              state,
              batch,
              ddpm_params,
              train=True,
              optimize_fn=None,
              pmap_axis_name='batch'):
  """The objective function for DDPM.

  Same as NCSN but with different noise perturbations. Mostly copied
  from https://github.com/hojonathanho/diffusion.

  Does one step of training or evaluation.
  Stores EMA statistics during training and evaluates with EMA.
  Will be called by jax.pmap using `pmap_axis_name`.

  Args:
    rng: a jax random state.
    state: a pytree of training states, including the optimizer, lr, etc.
    batch: a pytree of data points.
    ddpm_params: a dictionary containing betas, alphas, and others.
    train: True if we will train the model. Otherwise just do the evaluation.
    optimize_fn: takes state and grad and performs one optimization step.
    pmap_axis_name: the axis_name used when calling this function with pmap.

  Returns:
    loss, new_state
  """

  x = batch['image']
  rng1, rng2 = random.split(rng)
  betas = jnp.asarray(ddpm_params['betas'], dtype=jnp.float32)
  sqrt_alphas_cumprod = jnp.asarray(
      ddpm_params['sqrt_alphas_cumprod'], dtype=jnp.float32)
  sqrt_1m_alphas_cumprod = jnp.asarray(
      ddpm_params['sqrt_1m_alphas_cumprod'], dtype=jnp.float32)
  T = random.choice(rng1, len(betas), shape=(x.shape[0],))  # pylint: disable=invalid-name

  noise = random.normal(rng2, x.shape)

  perturbed_data = sqrt_alphas_cumprod[T, None, None, None] * x + \
                   sqrt_1m_alphas_cumprod[T, None, None, None] * noise

  run_rng, _ = random.split(rng2)

  @jax.jit
  def loss_fn(model):
    if train:
      with nn.stateful(state.model_state) as new_model_state:
        with nn.stochastic(run_rng):
          scores = model(perturbed_data, T, train=train)
    else:
      with nn.stateful(state.model_state, mutable=False):
        with nn.stochastic(run_rng):
          scores = model(perturbed_data, T, train=train)

      new_model_state = state.model_state

    scores = scores.reshape((scores.shape[0], -1))
    target = noise.reshape((noise.shape[0], -1))
    loss = jnp.mean((scores - target)**2)
    return loss, new_model_state

  if train:
    grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name=pmap_axis_name)
    ## WARNING: the gradient clip step differs slightly from the
    ## original DDPM implementation, and seems to be more reasonable.
    ## The impact of this difference on performance is negligible.
    new_optimizer = optimize_fn(state, grad)
    new_params_ema = jax.tree_map(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target.params)
    step = state.step + 1
    new_state = state.replace(  # pytype: disable=attribute-error
        step=step,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema)
  else:
    model_ema = state.optimizer.target.replace(params=state.params_ema)
    loss, _ = loss_fn(model_ema)
    new_state = state

  loss = jax.lax.pmean(loss, axis_name=pmap_axis_name)
  return loss, new_state
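# Example sketch (illustrative only): the kind of dictionary `ddpm_loss`
# expects. Only 'betas', 'sqrt_alphas_cumprod' and 'sqrt_1m_alphas_cumprod'
# are read above; the linear beta schedule and the extra keys follow the DDPM
# convention and are assumptions, not necessarily this repo's construction.
def _example_ddpm_params(num_timesteps=1000, beta_min=1e-4, beta_max=2e-2):
  import numpy as np
  betas = np.linspace(beta_min, beta_max, num_timesteps, dtype=np.float64)
  alphas = 1. - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  return {
      'betas': betas,
      'alphas': alphas,
      'alphas_cumprod': alphas_cumprod,
      'sqrt_alphas_cumprod': np.sqrt(alphas_cumprod),
      'sqrt_1m_alphas_cumprod': np.sqrt(1. - alphas_cumprod),
  }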