# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization."""

import flax
from flax.deprecated import nn
import jax
import jax.numpy as jnp
import jax.random as random


def get_optimizer(config):
  optimizer = None
  if config.optim.optimizer == 'Adam':
    optimizer = flax.optim.Adam(beta1=config.optim.beta1, eps=config.optim.eps,
                                weight_decay=config.optim.weight_decay)
  else:
    raise NotImplementedError(
        f'Optimizer {config.optim.optimizer} not supported yet!')

  return optimizer


def optimization_manager(config):
  """Returns an optimize_fn based on config."""
  def optimize(state,
               grad,
               warmup=config.optim.warmup,
               grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    lr = state.lr
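    # Linear learning-rate warmup over the first `warmup` steps.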
    if warmup > 0:
      lr = lr * jnp.minimum(state.step / warmup, 1.0)
    if grad_clip >= 0:
      # Compute global gradient norm
      grad_norm = jnp.sqrt(
          sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
      # Clip gradient
      clipped_grad = jax.tree_map(
          lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
    else:  # disabling gradient clipping if grad_clip < 0
      clipped_grad = grad
    return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)

  return optimize
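
# Example wiring (a sketch only): `config` is assumed to provide the
# `config.optim.*` fields referenced above, and `state` the `optimizer`, `lr`,
# and `step` fields used by `optimize`:
#
#   optimizer_def = get_optimizer(config)
#   optimize_fn = optimization_manager(config)
#   new_optimizer = optimize_fn(state, grad)  # one optimization step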


def ncsn_loss(rng,
              state,
              batch,
              sigmas,
              continuous=False,
              train=True,
              optimize_fn=None,
              anneal_power=2.,
              loss_per_sigma=False,
              class_conditional=False,
              pmap_axis_name='batch'):
  """The objective function for NCSN.

  Does one step of training or evaluation.
  Stores EMA statistics during training and uses EMA for evaluation.
  Will be called by jax.pmap using `pmap_axis_name`.

  Args:
    rng: a jax random state.
    state: a pytree of training states, including the optimizer, lr, etc.
    batch: a pytree of data points.
    sigmas: a numpy array representing the array of noise levels.
    continuous: Use a continuous distribution of sigmas and sample from it.
    train: True if we will train the model. Otherwise just do the evaluation.
    optimize_fn: takes state and grad and performs one optimization step.
    anneal_power: the exponent used to balance losses across noise levels.
      Defaults to 2.
    loss_per_sigma: return the loss for each sigma separately.
    class_conditional: train a score-based model conditioned on class labels.
    pmap_axis_name: the axis_name used when calling this function with pmap.

  Returns:
    loss, new_state if not loss_per_sigma. Otherwise return loss, new_state,
    losses, and used_sigmas. Here used_sigmas are noise levels sampled in this
    mini-batch, and `losses` contains the loss value for each datapoint and
    noise level.
  """
  x = batch['image']
  rng1, rng2 = random.split(rng)
  if not continuous:
    labels = random.choice(rng1, len(sigmas), shape=(x.shape[0],))
    used_sigmas = sigmas[labels].reshape(
        (x.shape[0], *([1] * len(x.shape[1:]))))
  else:
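    # Continuous noise levels: sample sigma log-uniformly between the smallest
    # level sigmas[-1] and the largest level sigmas[0].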
    labels = random.uniform(
        rng1, (x.shape[0],),
        minval=jnp.log(sigmas[-1]),
        maxval=jnp.log(sigmas[0]))
    labels = jnp.exp(labels)
    used_sigmas = labels.reshape((x.shape[0], *([1] * len(x.shape[1:]))))

  if class_conditional:
    class_labels = batch['label']

  noise = random.normal(rng2, x.shape) * used_sigmas
  perturbed_data = noise + x

  run_rng, _ = random.split(rng2)
  @jax.jit
  def loss_fn(model):
    if train:
      with nn.stateful(state.model_state) as new_model_state:
        with nn.stochastic(run_rng):
          if not class_conditional:
            scores = model(perturbed_data, labels, train=train)
          else:
            scores = model(perturbed_data, labels, y=class_labels, train=train)
    else:
      with nn.stateful(state.model_state, mutable=False):
        with nn.stochastic(run_rng):
          if not class_conditional:
            scores = model(perturbed_data, labels, train=train)
          else:
            scores = model(perturbed_data, labels, y=class_labels, train=train)

      new_model_state = state.model_state

    scores = scores.reshape((scores.shape[0], -1))
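    # Denoising score matching target: the score of the Gaussian perturbation
    # kernel, -(perturbed_data - x) / sigma**2 = -noise / sigma**2.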
    target = -1 / (used_sigmas ** 2) * noise
    target = target.reshape((target.shape[0], -1))
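    # Per-example loss weighted by sigma**anneal_power; with the default
    # anneal_power=2 this is the lambda(sigma) = sigma**2 weighting that keeps
    # losses at different noise levels on a comparable scale.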
    losses = 1 / 2. * ((scores - target)**
                       2).sum(axis=-1) * used_sigmas.squeeze()**anneal_power
    loss = jnp.mean(losses)

    if loss_per_sigma:
      return loss, new_model_state, losses
    else:
      return loss, new_model_state

  if train:
    grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
    if loss_per_sigma:
      (loss, new_model_state, losses), grad = grad_fn(state.optimizer.target)
    else:
      (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name=pmap_axis_name)
    new_optimizer = optimize_fn(state, grad)
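    # Exponential moving average of the parameters:
    # p_ema <- ema_rate * p_ema + (1 - ema_rate) * p.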
    new_params_ema = jax.tree_map(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target.params)
    step = state.step + 1
    new_state = state.replace(  # pytype: disable=attribute-error
        step=step,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema)
  else:
    model_ema = state.optimizer.target.replace(params=state.params_ema)
    if loss_per_sigma:
      loss, _, losses = loss_fn(model_ema)  # pytype: disable=bad-unpacking
    else:
      loss, *_ = loss_fn(model_ema)

    new_state = state

  loss = jax.lax.pmean(loss, axis_name=pmap_axis_name)
  if loss_per_sigma:
    return loss, new_state, losses, used_sigmas.squeeze()
  else:
    return loss, new_state


def ddpm_loss(rng,
              state,
              batch,
              ddpm_params,
              train=True,
              optimize_fn=None,
              pmap_axis_name='batch'):
190
  """The objective function for DDPM.
191

192
  Same as NCSN but with different noise perturbations. Mostly copied
193
  from https://github.com/hojonathanho/diffusion.
194

195
  Does one step of training or evaluation.
196
  Store EMA statistics during training and evaluate with EMA.
197
  Will be called by jax.pmap using `pmap_axis_name`.
198

199
  Args:
200
    rng: a jax random state.
201
    state: a pytree of training states, including the optimizer, lr, etc.
202
    batch: a pytree of data points.
203
    ddpm_params: a dictionary containing betas, alphas, and others.
204
    train: True if we will train the model. Otherwise just do the evaluation.
205
    optimize_fn: takes state and grad and performs one optimization step.
206
    pmap_axis_name: the axis_name used when calling this function with pmap.
207

208
  Returns:
209
    loss, new_state
210
  """

  x = batch['image']
  rng1, rng2 = random.split(rng)
  betas = jnp.asarray(ddpm_params['betas'], dtype=jnp.float32)
  sqrt_alphas_cumprod = jnp.asarray(
      ddpm_params['sqrt_alphas_cumprod'], dtype=jnp.float32)
  sqrt_1m_alphas_cumprod = jnp.asarray(
      ddpm_params['sqrt_1m_alphas_cumprod'], dtype=jnp.float32)
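  # Sample a diffusion timestep uniformly at random for each example.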
  T = random.choice(rng1, len(betas), shape=(x.shape[0],))  # pylint: disable=invalid-name

  noise = random.normal(rng2, x.shape)
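
  # Forward diffusion marginal: perturbed_data = sqrt(alpha_bar_t) * x
  # + sqrt(1 - alpha_bar_t) * noise, i.e. a sample from q(x_t | x_0).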
  perturbed_data = sqrt_alphas_cumprod[T, None, None, None] * x + \
      sqrt_1m_alphas_cumprod[T, None, None, None] * noise

  run_rng, _ = random.split(rng2)

  @jax.jit
  def loss_fn(model):
    if train:
      with nn.stateful(state.model_state) as new_model_state:
        with nn.stochastic(run_rng):
          scores = model(perturbed_data, T, train=train)
    else:
      with nn.stateful(state.model_state, mutable=False):
        with nn.stochastic(run_rng):
          scores = model(perturbed_data, T, train=train)

      new_model_state = state.model_state

    scores = scores.reshape((scores.shape[0], -1))
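    # Epsilon-prediction objective: despite the name `scores`, the model output
    # is trained to match the Gaussian noise added above (the simple DDPM loss).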
    target = noise.reshape((noise.shape[0], -1))
    loss = jnp.mean((scores - target)**2)
    return loss, new_model_state

  if train:
    grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name=pmap_axis_name)
    ## WARNING: the gradient clipping step differs slightly from the
    ## original DDPM implementation, and seems to be more reasonable.
    ## The impact of this difference on performance is negligible.
    new_optimizer = optimize_fn(state, grad)
    new_params_ema = jax.tree_map(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target.params)
    step = state.step + 1
    new_state = state.replace(  # pytype: disable=attribute-error
        step=step,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema)
  else:
    model_ema = state.optimizer.target.replace(params=state.params_ema)
    loss, _ = loss_fn(model_ema)
    new_state = state

  loss = jax.lax.pmean(loss, axis_name=pmap_axis_name)
  return loss, new_state