# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimizers."""

import chex
import jax
import jax.numpy as jnp
import optax


def clip_by_norm(updates,
                 l2_norms_threshold):
  """Standard clipping by L2 norm."""

  grad_norms = jax.tree_map(
      jax.vmap(jnp.linalg.norm),
      updates)
  divisors = jax.tree_map(
      lambda g_norm, l2_norm_clip: jnp.maximum(g_norm / l2_norm_clip, 1.0),
      grad_norms, l2_norms_threshold)
  return jax.tree_map(
      jax.vmap(lambda g, div: g / div),
      updates, divisors)
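

# A small illustrative sketch, not part of the original module: `clip_by_norm`
# takes a tree of per-example gradients and a matching tree of per-layer
# thresholds. The names and shapes below are hypothetical, chosen only to show
# the expected layout (batch dimension first).
def _clip_by_norm_example():
  """Clips a batch of 4 per-example gradients for one layer to norm 1.0."""
  per_example_grads = {'w': jnp.ones((4, 3))}  # each row has L2 norm sqrt(3)
  thresholds = {'w': 1.0}  # one clipping threshold per layer
  clipped = clip_by_norm(per_example_grads, thresholds)
  # Every per-example gradient in `clipped['w']` now has L2 norm at most 1.0.
  return clipped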


def dp_aggregate(
    l2_norms_threshold,
    base_sensitivity,
    noise_multiplier,
    init_rng,
):
  """Aggregates gradients based on the DP-SGD algorithm.

  This method clips per-example gradients to some l2 norm, sums them up,
  and adds noise to the sum.

  WARNING: Unlike other transforms, `dp_aggregate` expects
  the input updates to have a batch dimension in the 0th axis. That is, this
  function expects per-example gradients as input (which are easy to obtain in
  JAX using `jax.vmap`). It can still be composed with other transformations as
  long as it is the first in the chain.
  Further, each per-example gradient must already be divided by the batch size.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norms_threshold: max L2 norm of the per-example gradients for each layer.
    base_sensitivity: ratio of sensitivity to the clipping norm.
    noise_multiplier: ratio of noise standard deviation to the sensitivity.
    init_rng: initial jax.random.PRNGKey

  Returns:
    A `GradientTransformation`.
  """
  noise_stds = jax.tree_map(
      lambda l2_norm_clip: l2_norm_clip * base_sensitivity * noise_multiplier,
      l2_norms_threshold)

  def init_fn(params):
    del params
    return optax.DifferentiallyPrivateAggregateState(
        rng_key=init_rng)

  def update_fn(updates, state, params):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    batch_size = grads_flat[0].shape[0]

    if any(g.ndim == 0 or batch_size != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `dp_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
    rng_tree = jax.tree_unflatten(grads_treedef, rngs)

    clipped_updates = clip_by_norm(updates, l2_norms_threshold)
    summed_updates = jax.tree_map(
        lambda g: jnp.sum(g, axis=0),
        clipped_updates)
    noise = jax.tree_map(
        lambda g, std, rng: (std * jax.random.normal(rng, g.shape, g.dtype)),
        summed_updates, noise_stds, rng_tree)
    noisy_updates = jax.tree_map(lambda g, noise: (g + noise), summed_updates,
                                 noise)
    return (noisy_updates,
            optax.DifferentiallyPrivateAggregateState(rng_key=new_key))

  return optax.GradientTransformation(init_fn, update_fn)
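

# A minimal sketch, not part of the original module, of how to obtain the
# per-example gradients that `dp_aggregate` expects: `jax.vmap` over the batch
# axis yields one gradient per example, and each gradient is divided by the
# batch size as the docstring above requires. `loss_fn`, `params`, `inputs`
# and `labels` are hypothetical placeholders.
def _per_example_gradients_sketch(loss_fn, params, inputs, labels):
  """Returns per-example gradients, each scaled by 1 / batch_size."""
  batch_size = inputs.shape[0]
  # Gradient of the loss on a single example; vmap restores the batch axis.
  single_example_grad = jax.grad(loss_fn)
  per_example_grads = jax.vmap(single_example_grad, in_axes=(None, 0, 0))(
      params, inputs, labels)
  return jax.tree_map(lambda g: g / batch_size, per_example_grads)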


def dpsgd(learning_rate, l2_norms_threshold,
          base_sensitivity, noise_multiplier,
          init_rng, momentum,
          nesterov):
  """A differentially-private version of SGD."""
  return optax.chain(
      dp_aggregate(l2_norms_threshold, base_sensitivity, noise_multiplier,
                   init_rng), optax.sgd(learning_rate, momentum, nesterov))


def dpadam(learning_rate, l2_norms_threshold,
           base_sensitivity, noise_multiplier,
           init_rng):
  """A differentially-private version of Adam."""
  return optax.chain(
      dp_aggregate(l2_norms_threshold, base_sensitivity, noise_multiplier,
                   init_rng), optax.adam(learning_rate))
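

# A minimal end-to-end sketch, an assumption about typical usage rather than
# part of the original module: build the DP-SGD optimizer and apply a single
# update given per-example gradients that already satisfy `dp_aggregate`'s
# contract. The hyperparameter values are arbitrary; `l2_norms_threshold` must
# share the tree structure of `params`.
def _dpsgd_usage_sketch(params, per_example_grads):
  """Builds a DP-SGD optimizer and applies one update to `params`."""
  optimizer = dpsgd(
      learning_rate=0.1,
      l2_norms_threshold=jax.tree_map(lambda _: 1.0, params),
      base_sensitivity=1.0,
      noise_multiplier=1.1,
      init_rng=jax.random.PRNGKey(0),
      momentum=0.9,
      nesterov=False)
  # In a training loop, `optimizer` and `opt_state` would be created once and
  # threaded through every step so the noise RNG key keeps advancing.
  opt_state = optimizer.init(params)
  updates, opt_state = optimizer.update(per_example_grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state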