# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main train loop for DP-Adam. File intended to be mostly self-contained."""

import functools

from clu import metric_writers
from flax import jax_utils
import jax
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import compute_noise_from_budget_lib

from dp_transfer import data_utils
from dp_transfer import dataset
from dp_transfer import utils


def unreplicate_and_get(x):
  return jax.device_get(jax_utils.unreplicate(x))


def noisy(step, x, s, key):
  """Adds Gaussian noise with stddev s to x; no-op unless 0 < s < inf."""
  if 0 < s < np.inf:
    new_key = jax.random.fold_in(key, step)
    noise = jax.random.normal(new_key, shape=jnp.shape(x)) * s
    return x + noise
  return x
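

# Illustrative sketch (not used by the training loop): folding the step into
# the key yields an independent noise draw per optimizer step, and s outside
# (0, inf) turns `noisy` into the identity. Values below are hypothetical.
def _noisy_example():
  key = jax.random.PRNGKey(0)
  x = jnp.ones((3,))
  same = noisy(0, x, 0.0, key)  # s == 0 disables noise.
  y0 = noisy(0, x, 1.0, key)
  y1 = noisy(1, x, 1.0, key)  # Different step -> different noise draw.
  return same, y0, y1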


def one_hot(a, num_classes):
  return np.squeeze(np.eye(num_classes)[a.reshape(-1)])


def log_likelihood(weights, data, labels, bias):
  """Normalized negative log likelihood."""
  logits = jnp.einsum('d,ld->l', data, weights) + bias
  log_p, log_not_p = jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits)
  loss = -((labels * log_p) + (1. - labels) * log_not_p)
  return jnp.mean(loss)
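

# Illustrative sketch (not used by the training loop): the model is one-vs-all
# logistic regression on frozen features. For a single feature vector of size
# d, `weights` has shape (num_labels, d) and the loss averages a per-label
# sigmoid cross-entropy. The shapes below are hypothetical.
def _log_likelihood_example():
  d, num_labels = 4, 3
  w = jnp.zeros((num_labels, d))
  x = jnp.ones((d,))
  y = jnp.array([1.0, 0.0, 0.0])  # One-hot label for class 0.
  return log_likelihood(w, x, y, bias=-10.0)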


def log_likelihood_gradient(weights, data, labels, bias):
  """Gradient of negative log likelihood."""
  return jax.grad(lambda w: log_likelihood(w, data, labels, bias))(weights)


def clip(x, clip_norm=1.0):
  divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)
  return x / divisor
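

# Illustrative sketch (not used by the training loop): `clip` rescales a
# gradient so its L2 norm is at most clip_norm, and leaves smaller gradients
# untouched, as in standard DP-SGD clipping.
def _clip_example():
  g = jnp.array([3.0, 4.0])      # Norm 5: rescaled down to norm 1.
  clipped = clip(g, clip_norm=1.0)
  small = jnp.array([0.1, 0.0])  # Norm 0.1 <= 1: returned unchanged.
  return clipped, clip(small, clip_norm=1.0)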


def clipped_log_likelihood_gradient(weights, data, labels, bias, clip_norm):
  """Clipped gradient of negative log likelihood."""
  gradi = log_likelihood_gradient(weights, data, labels, bias)
  return clip(gradi, clip_norm)


def accumulate_grad(w, label_onehot, data, grad_accum, bias, clip_norm):
  """Adds the sum of per-example clipped gradients to grad_accum."""
  update_fn = jax.vmap(lambda data, labels: clipped_log_likelihood_gradient(
      w, data, labels, bias, clip_norm))
  grad_all = update_fn(data, label_onehot)
  gradi = grad_all.sum(0)
  return grad_accum + gradi
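

# Illustrative sketch (not used by the training loop): each example's gradient
# is clipped individually before summation, which bounds the contribution of
# any single example to the sum by clip_norm. Shapes below are hypothetical.
def _accumulate_grad_example():
  num_labels, d, n = 3, 4, 8
  w = jnp.zeros((num_labels, d))
  data = jnp.ones((n, d))
  label_onehot = jnp.tile(jnp.array([1.0, 0.0, 0.0]), (n, 1))
  grad_accum = jnp.zeros_like(w)
  summed = accumulate_grad(w, label_onehot, data, grad_accum,
                           bias=-10.0, clip_norm=1.0)
  return summed  # By the triangle inequality, norm is at most n * clip_norm.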


def update_from_accum_grad(
    step, w, final_grad, batch_size, lr, apply_noise_fn=None, reg=0.0
):
  """Makes an Adam-style update from the accumulated gradient."""
  final_grad = jax.lax.psum(final_grad, axis_name='batch')
  if apply_noise_fn is not None:
    final_grad = apply_noise_fn(step, final_grad)
  update = final_grad / batch_size
  b1 = 0.9
  b2 = 0.999
  # Adam moments estimated from a single accumulated gradient: this reduces to
  # a sign-like update scaled by (1. - b1) / (1. - b2).
  update = ((1. - b1) /
            (1. - b2)) * update / (jnp.sqrt(jax.lax.square(update)) + 1e-8)
  update += reg * w
  w -= lr * update
  return w, jnp.zeros_like(w)  # New weights and a zeroed accumulator.
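

# Numeric check (not used by the training loop): the update above is roughly
# (1. - b1) / (1. - b2) = 100 times the sign of each gradient coordinate; the
# 1e-8 term only guards against division by zero.
def _adam_style_update_example():
  g = jnp.array([2.0, -0.5, 0.0])
  scaled = ((1. - 0.9) / (1. - 0.999)) * g / (jnp.abs(g) + 1e-8)
  return scaled  # Approximately 100 * sign(g) elementwise.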


def train_and_evaluate(config, workdir):
  """Top level training and eval loop."""

  tf.io.gfile.makedirs(workdir)
  start_step = 0

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  num_epochs = config.num_epochs
  num_train_examples = 50000 if 'cifar' in config.dataset else 1281167
  local_batch_size = 1024
  num_acc_steps = num_train_examples // local_batch_size
  batch_size = local_batch_size * num_acc_steps
  num_steps_per_epoch = (num_train_examples // local_batch_size) + 1
  num_steps = num_steps_per_epoch * num_epochs
  print(f'num_steps: {num_steps}')
  print(f'num_steps_per_epoch: {num_steps_per_epoch}')
  print(f'lr: {config.lr}')
  print(f'num_acc_steps: {num_acc_steps}')
  print(f'batch_size: {batch_size}')

  data_config = data_utils.get_data_config(config)
  train_ds, test_ds = dataset.get_datasets(
      config=config,
      data_config=data_config,
      batch_size=local_batch_size,
      repeat=True
  )

  test_xs = []
  test_labels = []
  for x in test_ds:
    test_xs.append(x['repr'])
    test_labels.append(x['label'])
  test_x_np_list = utils.to_flat_np(
      test_xs, test_labels, data_config.num_labels
  )
  eval_step = jax.jit(
      functools.partial(
          utils.eval_step,
          test_x_np_list=test_x_np_list,
          hidden_dims=data_config.hidden_dims,
          num_labels=data_config.num_labels,
      ))

  # We only consider the full-batch setting, so the noise is calibrated with a
  # sampling rate of 1 (batch_size == num_train_examples).
  sigma = compute_noise_from_budget_lib.compute_noise(num_train_examples,
                                                      num_train_examples,
                                                      config.epsilon,
                                                      num_epochs,
                                                      data_config.delta, 1e-7)
  sigma *= data_config.clip
  key = jax.random.PRNGKey(config.seed)
  apply_noise_fn = None
  if config.is_private and config.epsilon > 0.0:
    # Gaussian mechanism: noise the summed, clipped gradient once per
    # accumulated step.
    apply_noise_fn = functools.partial(noisy, s=sigma, key=key)
  update_from_accum_grad_partial = functools.partial(
      update_from_accum_grad,
      batch_size=batch_size,
      apply_noise_fn=apply_noise_fn,
      lr=config.lr,
      reg=config.reg)
  update_from_accum_grad_partial_pmapped = jax.pmap(
      update_from_accum_grad_partial, axis_name='batch')
  accumulate_grad_pmapped = jax.pmap(
      functools.partial(
          accumulate_grad, bias=-10.0, clip_norm=data_config.clip
      ),
      axis_name='batch',
  )

  grad_accum = np.zeros(
      (data_config.num_labels, data_config.hidden_dims), np.float32
  )
  grad_accum = jax.device_put_replicated(grad_accum, devices=jax.devices())
  wopt = np.zeros((data_config.num_labels, data_config.hidden_dims), np.float32)
  wopt = jax.device_put_replicated(wopt, devices=jax.devices())

  train_iter = train_ds.as_numpy_iterator()
  for i in range(1, num_steps + 1):
    x = next(train_iter)
    data = x['repr']
    data = np.reshape(data,
                      (jax.device_count(), data.shape[0] // jax.device_count(),
                       data_config.hidden_dims))
    label_onehot = np.array(one_hot(x['label'], data_config.num_labels))
    label_onehot = np.reshape(label_onehot,
                              (jax.device_count(), label_onehot.shape[0] //
                               jax.device_count(), data_config.num_labels))
    grad_accum = accumulate_grad_pmapped(wopt, label_onehot, data, grad_accum)

    if i and i % num_acc_steps == 0:
      step = np.array([i] * jax.device_count())
      wopt, grad_accum = update_from_accum_grad_partial_pmapped(
          step, wopt, grad_accum)
      wopt_for_eval = unreplicate_and_get(wopt)
      eval_acc = eval_step(wopt_for_eval)
      print(f'eval acc at step: {i}, {eval_acc}')
      summary = {}
      summary['accuracy'] = eval_acc
      with metric_writers.ensure_flushes(writer):
        writer.write_scalars(i, summary)
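

# Usage sketch (hypothetical values; the real launcher and full config live
# elsewhere in the repo): `config` is an ml_collections.ConfigDict carrying
# the fields read directly in this file, and the workdir path is made up.
def _train_and_evaluate_example():
  config = ml_collections.ConfigDict()
  config.dataset = 'cifar100'
  config.num_epochs = 10
  config.lr = 0.1
  config.reg = 0.0
  config.seed = 0
  config.epsilon = 10.0
  config.is_private = True
  train_and_evaluate(config, workdir='/tmp/dp_adam')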