# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimizers."""

import chex
import jax
import jax.numpy as jnp
import optax


def clip_by_norm(updates, l2_norms_threshold):
  """Standard clipping by L2 norm."""
  # Per-example gradient norms: `jax.vmap` maps over the leading batch
  # dimension of each leaf.
  grad_norms = jax.tree_map(
      jax.vmap(jnp.linalg.norm),
      updates)
  # Scale down any per-example gradient whose norm exceeds its threshold;
  # gradients already below the threshold are divided by 1, i.e. unchanged.
  divisors = jax.tree_map(
      lambda g_norm, l2_norm_clip: jnp.maximum(g_norm / l2_norm_clip, 1.0),
      grad_norms, l2_norms_threshold)
  return jax.tree_map(
      jax.vmap(lambda g, div: g / div),
      updates, divisors)
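

def _clip_by_norm_example():
  """Minimal usage sketch for `clip_by_norm` (illustrative only).

  The tree layout, batch size, and thresholds below are assumptions made for
  this example; they are not part of the library code above.
  """
  # Each leaf carries a leading batch dimension (here, batch size 4).
  per_example_grads = {
      'w': jnp.ones((4, 3, 2)),
      'b': jnp.ones((4, 2)),
  }
  # One clipping threshold per layer, matching the gradient tree's structure.
  l2_norms_threshold = {'w': 1.0, 'b': 0.5}
  clipped = clip_by_norm(per_example_grads, l2_norms_threshold)
  # Every per-example leaf of `clipped` now has L2 norm at most its threshold.
  return clipped
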
def dp_aggregate(
    l2_norms_threshold,
    base_sensitivity,
    noise_multiplier,
    init_rng,
):
  """Aggregates gradients based on the DP-SGD algorithm.

  This method clips per-example gradients to some L2 norm, sums them up,
  and adds noise to the sum.

  WARNING: Unlike other transforms, `dp_aggregate` expects the input updates
  to have a batch dimension in the 0th axis. That is, this function expects
  per-example gradients as input (which are easy to obtain in JAX using
  `jax.vmap`). It can still be composed with other transformations as long as
  it is the first in the chain. Further, each per-example gradient must
  already be divided by the batch size. See the illustrative
  `_dp_aggregate_example` sketch below.

  References:
    [Abadi et al., 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norms_threshold: max L2 norm of the per-example gradients for each
      layer.
    base_sensitivity: ratio of the sensitivity to the clipping norm.
    noise_multiplier: ratio of the noise standard deviation to the
      sensitivity.
    init_rng: initial `jax.random.PRNGKey`.

  Returns:
    A `GradientTransformation`.
  """
  # Per-layer noise std: clipping threshold * base_sensitivity gives the
  # sensitivity of the sum, and the noise std is sensitivity * multiplier.
  noise_stds = jax.tree_map(
      lambda l2_norm_clip: l2_norm_clip * base_sensitivity * noise_multiplier,
      l2_norms_threshold)

  def init_fn(params):
    del params
    return optax.DifferentiallyPrivateAggregateState(rng_key=init_rng)

  def update_fn(updates, state, params):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    batch_size = grads_flat[0].shape[0]

    if any(g.ndim == 0 or batch_size != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `dp_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    # One fresh RNG per leaf of the gradient tree; carry `new_key` forward in
    # the state for the next step.
    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
    rng_tree = jax.tree_unflatten(grads_treedef, rngs)

    # Clip each per-example gradient, sum over the batch dimension, and add
    # Gaussian noise scaled to the per-layer sensitivity.
    clipped_updates = clip_by_norm(updates, l2_norms_threshold)
    summed_updates = jax.tree_map(
        lambda g: jnp.sum(g, axis=0),
        clipped_updates)
    noise = jax.tree_map(
        lambda g, std, rng: std * jax.random.normal(rng, g.shape, g.dtype),
        summed_updates, noise_stds, rng_tree)
    noisy_updates = jax.tree_map(lambda g, noise: g + noise, summed_updates,
                                 noise)
    return (noisy_updates,
            optax.DifferentiallyPrivateAggregateState(rng_key=new_key))

  return optax.GradientTransformation(init_fn, update_fn)
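

def _dp_aggregate_example():
  """Minimal usage sketch for `dp_aggregate` (illustrative only).

  Shapes, thresholds, and hyperparameters here are assumptions made for this
  example, not part of the library code above.
  """
  params = {'w': jnp.zeros((3, 2))}
  tx = dp_aggregate(
      l2_norms_threshold={'w': 1.0},
      base_sensitivity=1.0,
      noise_multiplier=1.1,
      init_rng=jax.random.PRNGKey(0))
  state = tx.init(params)
  # Per-example gradients: note the leading batch dimension (size 8) and that
  # each example's gradient is already divided by the batch size.
  per_example_grads = {'w': jnp.ones((8, 3, 2)) / 8}
  noisy_sum, state = tx.update(per_example_grads, state, params)
  # `noisy_sum` has the same shapes as `params`: the batch dimension is gone.
  return noisy_sum
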

def dpsgd(learning_rate, l2_norms_threshold,
          base_sensitivity, noise_multiplier,
          init_rng, momentum,
          nesterov):
  """A differentially-private version of SGD."""
  return optax.chain(
      dp_aggregate(l2_norms_threshold, base_sensitivity, noise_multiplier,
                   init_rng),
      optax.sgd(learning_rate, momentum, nesterov))
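

def _dpsgd_usage_example():
  """One optimizer step with `dpsgd` (illustrative only).

  Hyperparameters, shapes, and gradient values are assumptions made for this
  example, not part of the library code above.
  """
  optimizer = dpsgd(
      learning_rate=0.1,
      l2_norms_threshold={'w': 1.0},
      base_sensitivity=1.0,
      noise_multiplier=1.1,
      init_rng=jax.random.PRNGKey(42),
      momentum=0.9,
      nesterov=False)
  params = {'w': jnp.zeros(3)}
  opt_state = optimizer.init(params)
  # Per-example gradients with the batch dimension first, each already
  # divided by the batch size (here 8), as `dp_aggregate` requires.
  per_example_grads = {'w': jnp.ones((8, 3)) / 8}
  updates, opt_state = optimizer.update(per_example_grads, opt_state, params)
  return optax.apply_updates(params, updates)
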

def dpadam(learning_rate, l2_norms_threshold,
           base_sensitivity, noise_multiplier,
           init_rng):
  """A differentially-private version of Adam."""
  return optax.chain(
      dp_aggregate(l2_norms_threshold, base_sensitivity, noise_multiplier,
                   init_rng),
      optax.adam(learning_rate))
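

def _dpadam_train_step_sketch():
  """Sketch of a jitted training step with `dpadam` (illustrative only).

  The linear model, data shapes, and hyperparameters are assumptions made for
  this example, not part of the library code above.
  """
  optimizer = dpadam(
      learning_rate=1e-3,
      l2_norms_threshold={'w': 1.0},
      base_sensitivity=1.0,
      noise_multiplier=1.1,
      init_rng=jax.random.PRNGKey(0))
  params = {'w': jnp.zeros(3)}
  opt_state = optimizer.init(params)
  xs, ys = jnp.ones((8, 3)), jnp.zeros(8)
  batch_size = xs.shape[0]

  def per_example_loss(p, x, y):
    # Divide by the batch size here, as `dp_aggregate` requires.
    return (jnp.dot(x, p['w']) - y) ** 2 / batch_size

  @jax.jit
  def train_step(params, opt_state, xs, ys):
    # Per-example gradients: vmap the gradient over the batch dimension.
    grads = jax.vmap(
        jax.grad(per_example_loss), in_axes=(None, 0, 0))(params, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

  return train_step(params, opt_state, xs, ys)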