google-research / train_gd_clip_grad.py (206 lines, 6.9 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main train loop for DP-Adam. File intended to be mostly self-contained."""

import functools

from clu import metric_writers
from flax import jax_utils
import jax
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import compute_noise_from_budget_lib

from dp_transfer import data_utils
from dp_transfer import dataset
from dp_transfer import utils


def unreplicate_and_get(x):
  return jax.device_get(jax_utils.unreplicate(x))


def noisy(step, x, s, key):
  """Add N(0, s^2) Gaussian noise to x, folding the step into the key."""
  if 0 < s < np.inf:
    new_key = jax.random.fold_in(key, step)
    noise = jax.random.normal(new_key, shape=jnp.shape(x)) * s
    return x + noise
  return x


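# Note on `noisy`: jax.random.fold_in derives a deterministic per-step subkey,
# so successive steps never reuse noise and the same (key, step) pair always
# reproduces the same draw; s <= 0 or s == np.inf skips noising entirely
# (the non-private path).
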
def one_hot(a, num_classes):
  return np.squeeze(np.eye(num_classes)[a.reshape(-1)])


def log_likelihood(weights, data, labels, bias):
  """Normalized negative log likelihood."""
  logits = jnp.einsum('d,ld->l', data, weights) + bias
  log_p, log_not_p = jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits)

  loss = -((labels * log_p) + (1. - labels) * log_not_p)
  return jnp.mean(loss)


def log_likelihood_gradient(weights, data, labels, bias):
  """Gradient of negative log likelihood."""
  return jax.grad(lambda w: log_likelihood(w, data, labels, bias))(weights)


def clip(x, clip_norm=1.0):
  """Scale x so its L2 norm is at most clip_norm (no-op if already within)."""
  divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)
  return x / divisor


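# Worked example for `clip` (hypothetical values), with clip_norm=1.0:
#   clip(jnp.array([3., 4.]))  -> [0.6, 0.8]   (norm 5 scaled down to 1)
#   clip(jnp.array([.3, .4]))  -> [0.3, 0.4]   (norm 0.5 <= 1, unchanged)
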
def clipped_log_likelihood_gradient(weights, data, labels, bias, clip_norm):
  """Per-example gradient of negative log likelihood, clipped to clip_norm."""
  gradi = log_likelihood_gradient(weights, data, labels, bias)
  return clip(gradi, clip_norm)


def accumulate_grad(w, label_onehot, data, grad_accum, bias, clip_norm):
  """Sum per-example clipped gradients over a microbatch into grad_accum."""
  update_fn = jax.vmap(lambda data, labels: clipped_log_likelihood_gradient(
      w, data, labels, bias, clip_norm))
  grad_all = update_fn(data, label_onehot)
  gradi = grad_all.sum(0)
  return grad_accum + gradi


def update_from_accum_grad(
    step, w, final_grad, batch_size, lr, apply_noise_fn=None, reg=0.0
):
  """Make an Adam-style update from the accumulated gradient."""
  final_grad = jax.lax.psum(final_grad, axis_name='batch')
  if apply_noise_fn is not None:
    final_grad = apply_noise_fn(final_grad, step)
  update = final_grad / batch_size
  b1 = 0.9
  b2 = 0.999
  # sqrt(square(u)) == |u|, so this reduces to a scaled elementwise sign step.
  update = ((1. - b1) /
            (1. - b2)) * update / ((jnp.sqrt(jax.lax.square(update))) + 1e-8)
  update += reg * w
  w -= lr * update
  return w, jnp.zeros_like(w)


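# Note on update_from_accum_grad: with b1 = 0.9 and b2 = 0.999 the factor
# (1 - b1) / (1 - b2) is 100 and sqrt(square(u)) == |u|, so each coordinate
# moves by roughly lr * 100 * sign(update) plus the L2 penalty lr * reg * w.
# The returned jnp.zeros_like(w) resets the accumulator for the next batch.
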
def train_and_evaluate(config, workdir):
  """Top level training and eval loop."""

  tf.io.gfile.makedirs(workdir)
  start_step = 0

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  num_epochs = config.num_epochs
  num_train_examples = 50000 if 'cifar' in config.dataset else 1281167
  local_batch_size = 1024
  num_acc_steps = num_train_examples // local_batch_size
  batch_size = local_batch_size * num_acc_steps
  num_steps_per_epoch = (num_train_examples // local_batch_size) + 1
  num_steps = num_steps_per_epoch * num_epochs
  print(f'num_steps: {num_steps}')
  print(f'num_steps_per_epoch: {num_steps_per_epoch}')
  print(f'lr: {config.lr}')
  print(f'num_acc_steps: {num_acc_steps}')
  print(f'batch_size: {batch_size}')

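  # Concretely, for CIFAR: num_acc_steps = 50000 // 1024 = 48, so the effective
  # full batch is 48 * 1024 = 49152 examples, num_steps_per_epoch = 49, and the
  # parameter update below fires every 48 microbatch steps.
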
  data_config = data_utils.get_data_config(config)
  train_ds, test_ds = dataset.get_datasets(
      config=config,
      data_config=data_config,
      batch_size=local_batch_size,
      repeat=True
  )

  test_xs = []
  test_labels = []
  for x in test_ds:
    test_xs.append(x['repr'])
    test_labels.append(x['label'])
  test_x_np_list = utils.to_flat_np(
      test_xs, test_labels, data_config.num_labels
  )
  eval_step = jax.jit(
      functools.partial(
          utils.eval_step,
          test_x_np_list=test_x_np_list,
          hidden_dims=data_config.hidden_dims,
          num_labels=data_config.num_labels,
      ))

  # We only consider full batch setting.
  sigma = compute_noise_from_budget_lib.compute_noise(num_train_examples,
                                                      num_train_examples,
                                                      config.epsilon,
                                                      num_epochs,
                                                      data_config.delta, 1e-7)
  sigma *= data_config.clip
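  # compute_noise searches for the smallest Gaussian noise multiplier meeting
  # the (epsilon, delta) budget over num_epochs full-batch iterations (the
  # sampling ratio is 1 because the batch size equals num_train_examples). The
  # multiplier assumes unit L2 sensitivity, so scaling by the clip norm gives
  # the absolute noise stddev for the summed clipped gradients.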
  key = jax.random.PRNGKey(config.seed)
  apply_noise_fn = None
  if config.is_private and config.epsilon > 0.0:
    # `noisy` already draws noise of the full gradient shape, so no vmap is
    # needed; the wrapper matches the (grad, step) call order used in
    # update_from_accum_grad (the vmapped partial received its arguments in
    # the wrong order).
    apply_noise_fn = lambda g, step: noisy(step, g, sigma, key)
  update_from_accum_grad_partial = functools.partial(
      update_from_accum_grad,
      batch_size=batch_size,
      apply_noise_fn=apply_noise_fn,
      lr=config.lr,
      reg=config.reg)
  update_from_accum_grad_partial_pmapped = jax.pmap(
      update_from_accum_grad_partial, axis_name='batch')
  accumulate_grad_pmapped = jax.pmap(
      functools.partial(
          accumulate_grad, bias=-10.0, clip_norm=data_config.clip
      ),
      axis_name='batch',
  )
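  # Each device accumulates clipped gradients for its own shard independently;
  # the cross-device reduction happens once per effective batch via the psum
  # inside update_from_accum_grad, so between updates grad_accum holds
  # per-device partial sums.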

  grad_accum = np.zeros(
      (data_config.num_labels, data_config.hidden_dims), np.float32
  )
  grad_accum = jax.device_put_replicated(grad_accum, devices=jax.devices())
  wopt = np.zeros((data_config.num_labels, data_config.hidden_dims), np.float32)
  wopt = jax.device_put_replicated(wopt, devices=jax.devices())

  train_iter = train_ds.as_numpy_iterator()
  for i in range(1, num_steps + 1):
    x = next(train_iter)
    data = x['repr']
    # Shard the microbatch across local devices: leading axis = device count.
    data = np.reshape(data,
                      (jax.device_count(), data.shape[0] // jax.device_count(),
                       data_config.hidden_dims))
    label_onehot = np.array(one_hot(x['label'], data_config.num_labels))
    label_onehot = np.reshape(label_onehot,
                              (jax.device_count(), label_onehot.shape[0] //
                               jax.device_count(), data_config.num_labels))
    grad_accum = accumulate_grad_pmapped(wopt, label_onehot, data, grad_accum)

    # After a full effective batch has been accumulated, apply the update and
    # evaluate.
    if i % num_acc_steps == 0:
      step = np.array([i] * jax.device_count())
      wopt, grad_accum = update_from_accum_grad_partial_pmapped(
          step, wopt, grad_accum)
      wopt_for_eval = unreplicate_and_get(wopt)
      eval_acc = eval_step(wopt_for_eval)
      print(f'eval acc at step: {i}, {eval_acc}')
      summary = {}
      summary['accuracy'] = eval_acc
      with metric_writers.ensure_flushes(writer):
        writer.write_scalars(i, summary)
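
# A minimal invocation sketch (hypothetical values; only the config fields this
# file reads directly are set here, and the real configs live elsewhere in the
# repo):
#
#   import ml_collections
#   config = ml_collections.ConfigDict()
#   config.dataset = 'cifar10'
#   config.num_epochs = 10
#   config.lr = 0.1
#   config.reg = 0.0
#   config.epsilon = 1.0
#   config.is_private = True
#   config.seed = 0
#   train_and_evaluate(config, '/tmp/dp_gd_clip_grad')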