# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Newton step DP sanitizer."""

import functools

import jax
import jax.numpy as jnp
import numpy as np


def clip(x, clip_norm=1.0):
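  """Scales x so that its L2 norm is at most clip_norm."""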
  divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)
  return x / divisor


def noisy(x, s, key):
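  """Adds Gaussian noise of scale s to x; no-op unless 0 < s < inf."""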
  if 0 < s < np.inf:
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, shape=jnp.shape(x)) * s
    return x + noise
  return x


class Sanitizer(object):
  """Provides a set of functions for sanitizing private information.

  There are utilities used during pre-processing of the data, and others used
  during training. Sanitizer also provides a compute_epsilon function to compute
  the privacy loss for its sanitization process.

  Training:
    clip: clips the norm of each user embedding. Used during row solves.
    apply_noise: adds noise to a list of the input sufficient statistics
      involving the user data. Used during column solves.

  See the usage sketch at the end of this module.
  """

  def __init__(self,
               steps,
               max_norm=1,
               num_classes=1000,
               s1=0,
               s2=0,
               random_seed=None):
    """Initializes a Sanitizer.

    Args:
      steps: number of optimization steps.
      max_norm: the maximum L2 norm to which each user embedding is clipped.
      num_classes: number of classifier classes.
      s1: the noise factor for the local Gramian.
      s2: the noise factor for the rhs.
      random_seed: rng seed for sampling.
    """
    if random_seed:
      self.key = jax.random.PRNGKey(random_seed)
    else:
      self.key = jax.random.PRNGKey(42)
    self.max_norm = max_norm
    self.sigmas = [s1, s2]
    self.steps = steps
    self.num_classes = num_classes

  def refresh_key(self):
    """Refreshes the PRNG key; each key should be used only once."""
    _, self.key = jax.random.split(self.key)

  def clip(self, embeddings):
    if not self.max_norm:
      return embeddings
    return jax.vmap(functools.partial(clip, clip_norm=self.max_norm))(
        embeddings)

  def _project_psd(self, x, rank):
    """Project a rank 2 or 3 tensor to PSD."""
    if rank == 2:
      indices = [1, 0]
    elif rank == 3:
      indices = [0, 2, 1]
    else:
      raise ValueError("rank must be 2 or 3")

    def transpose(x):
      return jnp.transpose(x, indices)

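    # Symmetrize x, zero out negative eigenvalues, and reconstruct.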
    x = (x + transpose(x)) / 2
    e, v = jnp.linalg.eigh(x)
    e = jnp.maximum(e, 0)
    return v @ (jnp.expand_dims(e, -1) * transpose(v))

  def apply_noise(self, stats):
    """Applies noise to the (local_gramian, rhs) statistics."""
    if not isinstance(stats, (list, tuple)) or len(stats) != 2:
      raise ValueError("stats must be a pair of (local_gramian, rhs).")
    if not self.max_norm:
      return stats
    sigmas = self.sigmas
    max_norm = self.max_norm
    num_classes_scale = np.sqrt(self.num_classes)
    num_classes_scale_hessian = num_classes_scale
    num_classes_scale_gradient = num_classes_scale
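    # Scale each noise factor by the statistic's per-user L2 sensitivity:
    # max_norm**2 / 4 for the Hessian (lhs) and max_norm for the gradient
    # (rhs), each multiplied by sqrt(num_classes).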
    sigmas = [
        num_classes_scale_hessian * sigmas[0] * (max_norm**2) / 4.0,
        num_classes_scale_gradient * sigmas[1] * max_norm
    ]
    # Get fresh keys.
    keys = jax.random.split(self.key, num=2)

    lhs, rhs = [
        noisy(x, s, key=key)
        for x, s, key in zip(stats, sigmas, keys)
    ]
    lhs = self._project_psd(lhs, rank=2)
    return lhs, rhs

  def compute_epsilon(self, target_delta):
    """Computes the DP epsilon at the given target_delta."""
    if not all(self.sigmas):
      return np.inf
    # The accounting is done as follows: whenever we compute a statistic with
    # L2 sensitivity k and add Gaussian noise of scale σ, the procedure is
    # (α, αβ/2)-RDP with β = k²/σ². To compose RDP processes, we sum their β.
    s1, s2 = self.sigmas
    s1_multiplier = self.steps
    s2_multiplier = self.steps
    beta = (s1_multiplier / (s1**2) + s2_multiplier / (s2**2))
    # We translate (α, αβ/2)-RDP to (ε, δ)-DP with ε = αβ/2 + log(1/δ)/(α−1).
    # We pick the α that minimizes ε, which is α = 1 + √(2log(1/δ)/β).
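    # Example (illustrative values): with steps=10, s1=s2=1 and δ=1e-5,
    # β = 20, the minimizing α ≈ 2.07, and ε ≈ 31.5.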
    alpha = 1.0 + np.sqrt(np.log(1.0 / target_delta) * 2.0 / beta)
    eps = alpha * beta / 2.0 + np.log(1.0 / target_delta) / (alpha - 1.0)
    return eps

  def set_sigmas(self,
                 target_epsilon,
                 target_delta,
                 sigma_ratio1=1):
    """Sets sigmas to get the target (epsilon, delta).

    Args:
      target_epsilon: the desired epsilon.
      target_delta: the desired delta.
      sigma_ratio1: the ratio sigma1/sigma2.
    """
    s_lower = 1e-6
    s_upper = 1e6

    def get_epsilon(s):
      self.sigmas = [sigma_ratio1 * s, s]
      return self.compute_epsilon(target_delta)

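    # compute_epsilon decreases monotonically in s, so bisect on s until the
    # resulting epsilon is within 0.01% of the target.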
    eps = get_epsilon(s_lower)
    i = 0
    while np.abs(eps / target_epsilon - 1) > 0.0001:
      s = (s_lower + s_upper) / 2
      eps = get_epsilon(s)
      if eps > target_epsilon:
        s_lower = s
      else:
        s_upper = s
      i += 1
      if i > 1000:
        raise ValueError(
            f"No value of sigmas found for the desired (epsilon, delta)"
            f"={target_epsilon, target_delta}. Consider increasing stddev.")
    s1, s2 = self.sigmas
    print(
        f"Setting sigmas to [{s1:.2f}, {s2:.2f}], given target "
        f"(epsilon, delta)={target_epsilon, target_delta}")