google-research

Форк
0
/
linear_regression_sanitizer.py 
181 строка · 5.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Linear regression DP sanitizer."""
17

18
import functools
19

20
import jax
21
import jax.numpy as jnp
22
import numpy as np
23

24

25
def clip(x, clip_norm=1.0):
  """Rescales `x` so that its 2-norm (over all entries) is at most `clip_norm`.

  Args:
    x: array to clip.
    clip_norm: maximum allowed norm.

  Returns:
    `x` unchanged if its norm is already <= `clip_norm`, otherwise `x` scaled
    down to have norm exactly `clip_norm`.
  """
  norm_ratio = jnp.linalg.norm(x) / clip_norm
  # Only ever shrink: a ratio below 1 means `x` is already inside the ball.
  return x / jnp.maximum(norm_ratio, 1.)
28

29

30
def noisy(x, s, key):
  """Perturbs `x` with i.i.d. Gaussian noise of standard deviation `s`.

  Args:
    x: array to perturb.
    s: noise standard deviation. If `s` is not strictly between 0 and
      infinity (0, negative, inf or NaN), no noise is drawn and `x` itself is
      returned.
    key: JAX PRNG key used to sample the noise.

  Returns:
    `x` plus noise of the same shape as `x`, or `x` itself.
  """
  if not 0 < s < np.inf:
    return x
  standard_normal = jax.random.normal(key, shape=jnp.shape(x))
  return x + standard_normal * s
35

36

37
class Sanitizer(object):
  """Provides a set of functions for sanitizing private information.

  Sanitizer clips and noises the statistics used during training. It also
  provides a compute_epsilon function that computes the privacy loss of its
  sanitization process.

  Training:
    clip: clips the norm of each user embedding. Used during row solves.
    apply_noise_gramian: adds noise to the global Gramian and projects the
      result back onto the PSD cone.
    apply_noise: adds noise to a (local_gramian, rhs) pair of sufficient
      statistics involving the user data. Used during column solves.

  Accounting:
    compute_epsilon: epsilon of one step for the current noise factors.
    set_sigmas: searches for noise factors that reach a target
      (epsilon, delta).

  Randomness: the noising methods derive their keys from `self.key` without
  advancing it, so they do not mutate state. Call `refresh_key` once per
  round, after noise has been drawn, so the next round gets fresh noise.
  """

  # Positions of the per-purpose subkeys returned by `_subkeys`. The next
  # state key is a sibling of the subkeys used for noise. No key is therefore
  # ever both consumed for sampling and split again.
  _NEXT_KEY, _GRAMIAN_KEY, _LHS_KEY, _RHS_KEY = range(4)

  def __init__(self, max_norm=1, s0=0, s1=0, s2=0, random_seed=None):
    """Initializes a Sanitizer.

    Args:
      max_norm: clips the user embeddings to max_norm. A falsy value disables
        both clipping and the noise added by `apply_noise`.
      s0: the noise factor for global gramian.
      s1: the noise factor for local gramian.
      s2: the noise factor for rhs.
      random_seed: seed of the PRNG. Defaults to 42 when None. Any other
        value, including 0, is used as given.
    """
    # Compare against None so that a seed of 0 is honoured instead of being
    # silently replaced by the default.
    seed = 42 if random_seed is None else random_seed
    self.key = jax.random.PRNGKey(seed)
    self.max_norm = max_norm
    self.sigmas = [s0, s1, s2]

  def _subkeys(self):
    """Returns 4 independent subkeys split from the current state key."""
    return jax.random.split(self.key, num=4)

  def refresh_key(self):
    """Advances the PRNG state; call it once the current keys have been used.

    The new state key is split from the current one alongside the subkeys
    consumed by `apply_noise_gramian` and `apply_noise`, and is distinct from
    them. The next round therefore never reuses a key from this round.
    """
    self.key = self._subkeys()[self._NEXT_KEY]

  def clip(self, embeddings):
    """Clips each slice of `embeddings` along axis 0 to norm <= max_norm.

    Returns `embeddings` unchanged when `max_norm` is falsy.
    """
    if not self.max_norm:
      return embeddings
    # vmap over the leading axis: each embedding is clipped independently.
    return jax.vmap(functools.partial(clip, clip_norm=self.max_norm))(
        embeddings)

  def _project_psd(self, x, rank):
    """Project a rank 2 or 3 tensor to PSD.

    Symmetrizes the trailing two axes, then clamps negative eigenvalues to 0.
    For rank 3, the projection is applied independently to each matrix along
    the leading axis.

    Raises:
      ValueError: if rank is not 2 or 3.
    """
    if rank == 2:
      indices = [1, 0]
    elif rank == 3:
      indices = [0, 2, 1]
    else:
      raise ValueError("rank must be 2 or 3")

    def transpose(x):
      return jnp.transpose(x, indices)

    x = (x + transpose(x)) / 2
    e, v = jnp.linalg.eigh(x)
    e = jnp.maximum(e, 0)
    # V diag(e) V^T, written as row-scaling of V^T to avoid building diag(e).
    return v @ (jnp.expand_dims(e, -1) * transpose(v))

  def apply_noise_gramian(self, global_gramian):
    """Adds Gaussian noise to the global Gramian and projects it to PSD.

    Args:
      global_gramian: the global Gramian, a square 2-D matrix.

    Returns:
      The noisy Gramian projected onto the PSD cone. The noise std is
      sigmas[0] * max_norm**2.
    """
    sigma = self.sigmas[0] * (self.max_norm**2)
    key = self._subkeys()[self._GRAMIAN_KEY]
    gram = noisy(global_gramian, sigma, key=key)
    return self._project_psd(gram, rank=2)

  def apply_noise(self, stats):
    """Apply noise to stats.

    Args:
      stats: a list or tuple (local_gramian, rhs).

    Returns:
      `stats` unchanged if max_norm is falsy. Otherwise a tuple (lhs, rhs):
      lhs is the local Gramian plus noise of std sigmas[1] * max_norm**2,
      projected to PSD; rhs is the rhs plus noise of std sigmas[2] * max_norm.

    Raises:
      ValueError: if stats is not a list/tuple of length 2.
    """
    if not isinstance(stats, (list, tuple)) or len(stats) != 2:
      raise ValueError("stats must be a pair of (local_gramian, rhs).")
    if not self.max_norm:
      return stats
    _, s1, s2 = self.sigmas
    local_gramian, rhs = stats

    # Get fresh keys, disjoint from the gramian key and the next state key.
    subkeys = self._subkeys()
    lhs = noisy(local_gramian, s1 * (self.max_norm**2),
                key=subkeys[self._LHS_KEY])
    rhs = noisy(rhs, s2 * self.max_norm, key=subkeys[self._RHS_KEY])
    lhs = self._project_psd(lhs, rank=2)
    return lhs, rhs

  def compute_epsilon(self, target_delta):
    """Computes the epsilon of one ALS step for the current sigmas.

    Args:
      target_delta: the delta of the (epsilon, delta)-DP guarantee. Must be
        in (0, 1).

    Returns:
      The epsilon, or np.inf when some statistic would be released without
      noise. That happens when a noise factor is 0/None, negative or
      infinite, since `noisy` adds no noise in those cases.

    Raises:
      ValueError: if target_delta is not in (0, 1).
    """
    if not 0 < target_delta < 1:
      raise ValueError(
          f"target_delta must be in (0, 1), got {target_delta}.")
    # `noisy` only adds noise when 0 < s < inf. Any other factor means the
    # corresponding statistic is not protected.
    if not all(s and 0 < s < np.inf for s in self.sigmas):
      return np.inf
    # The accounting is done as follows: whenever we compute a statistic with
    # L2 sensitivity k and add Gaussian noise of scale σ, the procedure is
    # (α, αβ/2)-RDP with β = k²/σ². To compose RDP processes, we sum their β.
    # β accounts for running 1 step of ALS, which involves the following
    # computations:
    # - one global Gramian (k² = 1, σ = s0)
    # - local Gramians (k² = budget, σ = s1)
    # - RHSs (k² = budget, σ = s2)
    s0, s1, s2 = self.sigmas
    beta = 1 / (s0**2) + 1 / (s1**2) + 1 / (s2**2)
    # We translate (α, αβ/2)-RDP to (ε, δ)-DP with ε = αβ/2 + log(1/δ)/(α−1).
    # We pick the α that minimizes ε, which is α = 1 + √(2log(1/δ)/β)
    alpha = 1.0 + np.sqrt(np.log(1.0 / target_delta) * 2.0 / beta)
    eps = alpha * beta / 2.0 + np.log(1.0 / target_delta) / (alpha - 1.0)
    return eps

  def set_sigmas(self,
                 target_epsilon,
                 target_delta,
                 sigma_ratio0=1,
                 sigma_ratio1=1):
    """Sets sigmas to get the target (epsilon, delta).

    Bisects on sigma2 in [1e-6, 1e6], with sigma0 and sigma1 tied to it by
    the given ratios. It stops once the epsilon is within a relative 1e-4 of
    target_epsilon.

    Args:
      target_epsilon: the desired epsilon.
      target_delta: the desired delta.
      sigma_ratio0: the ratio sigma0/sigma2.
      sigma_ratio1: the ratio sigma1/sigma2.

    Raises:
      ValueError: if no such sigmas are found within 1000 bisection steps, or
        if target_delta is not in (0, 1). The previous sigmas are restored in
        both cases.
    """
    previous_sigmas = self.sigmas
    s_lower = 1e-6
    s_upper = 1e6

    def get_epsilon(s):
      self.sigmas = [sigma_ratio0 * s, sigma_ratio1 * s, s]
      return self.compute_epsilon(target_delta)

    try:
      eps = get_epsilon(s_lower)
      i = 0
      while np.abs(eps / target_epsilon - 1) > 0.0001:
        s = (s_lower + s_upper) / 2
        eps = get_epsilon(s)
        # Epsilon decreases as the noise grows. An epsilon above the target
        # therefore means more noise is needed.
        if eps > target_epsilon:
          s_lower = s
        else:
          s_upper = s
        i += 1
        if i > 1000:
          raise ValueError(
              f"No value of sigmas found for the desired (epsilon, delta)="
              f"{target_epsilon, target_delta}. Consider increasing stddev.")
    except ValueError:
      # Do not leave half-searched sigmas behind on failure.
      self.sigmas = previous_sigmas
      raise
    s0, s1, s2 = self.sigmas
    print(f"Setting sigmas to [{s0:.2f}, {s1:.2f}, {s2:.2f}], given target "
          f"(epsilon, delta)={target_epsilon, target_delta}")
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.