# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear regression DP sanitizer."""

import functools

import jax
import jax.numpy as jnp
import numpy as np


def clip(x, clip_norm=1.0):
  divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)
  return x / divisor


def noisy(x, s, key):
  if 0 < s < np.inf:
    noise = jax.random.normal(key, shape=jnp.shape(x)) * s
    return x + noise
  return x
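
# A minimal sketch (not part of the original file) of how the two helpers
# compose into the Gaussian mechanism: `clip` bounds the L2 norm (and hence
# the sensitivity) of a vector, and `noisy` adds Gaussian noise of scale s:
#   key = jax.random.PRNGKey(0)
#   x = jnp.ones(8) * 3.0
#   x_clipped = clip(x, clip_norm=1.0)            # now ||x_clipped||_2 <= 1
#   x_private = noisy(x_clipped, s=0.5, key=key)  # Gaussian mechanism output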


class Sanitizer(object):
  """Provides a set of functions for sanitizing private information.

  There are utilities used during pre-processing of the data, and others used
  during training. Sanitizer also provides a compute_epsilon function to
  compute the privacy loss for its sanitization process.

  Data pre-processing:
    preprocess_data: 1) centers ratings, 2) estimates counts, 3) restricts to
      head items, 4) adaptively samples items, 5) re-estimates counts on
      sampled data.

  Training:
    clip: clips the norm of each user embedding. Used during row solves.
    apply_noise: adds noise to a list of the input sufficient statistics
      involving the user data. Used during column solves.
  """

  def __init__(self, max_norm=1, s0=0, s1=0, s2=0, random_seed=None):
    """Initializes a Sanitizer.

    Args:
      max_norm: clips the user embeddings to max_norm.
      s0: the noise factor for global gramian.
      s1: the noise factor for local gramian.
      s2: the noise factor for rhs.
      random_seed: seed.
    """
    if random_seed:
      self.key = jax.random.PRNGKey(random_seed)
    else:
      self.key = jax.random.PRNGKey(42)
    self.max_norm = max_norm
    self.sigmas = [s0, s1, s2]

  def refresh_key(self):
    """A PRNG key should be used only once; this refreshes it after use."""
    _, self.key = jax.random.split(self.key)

  def clip(self, embeddings):
    if not self.max_norm:
      return embeddings
    return jax.vmap(functools.partial(clip, clip_norm=self.max_norm))(
        embeddings)

  def _project_psd(self, x, rank):
    """Project a rank 2 or 3 tensor to PSD."""
    if rank == 2:
      indices = [1, 0]
    elif rank == 3:
      indices = [0, 2, 1]
    else:
      raise ValueError("rank must be 2 or 3")

    def transpose(x):
      return jnp.transpose(x, indices)

    x = (x + transpose(x)) / 2
    e, v = jnp.linalg.eigh(x)
    e = jnp.maximum(e, 0)
    return v @ (jnp.expand_dims(e, -1) * transpose(v))
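
  # Illustrative check (not part of the original file): for rank 2, the input
  # x = [[0., 2.], [0., 0.]] is first symmetrized to [[0., 1.], [1., 0.]],
  # whose eigenvalues are +1 and -1; clamping the negative eigenvalue to 0
  # and reconstructing yields [[0.5, 0.5], [0.5, 0.5]], which is PSD.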

  def apply_noise_gramian(self, global_gramian):
    """Adds noise to the global gramian and projects it back to PSD."""
    sigmas = self.sigmas
    max_norm = self.max_norm
    sigma = sigmas[0] * (max_norm**2)

    gram = noisy(global_gramian, sigma, key=self.key)
    gram = self._project_psd(gram, rank=2)
    return gram

  def apply_noise(self, stats):
    """Applies noise to stats."""
    if not isinstance(stats, (list, tuple)) or len(stats) != 2:
      raise ValueError("stats must be a pair of (local_gramian, rhs).")
    if not self.max_norm:
      return stats
    sigmas = self.sigmas
    max_norm = self.max_norm
    sigmas = [sigmas[1] * (max_norm**2), sigmas[2] * max_norm]

    # Get fresh keys.
    keys = jax.random.split(self.key, num=2)

    lhs, rhs = [noisy(x, s, key=key) for x, s, key in zip(stats, sigmas, keys)]
    lhs = self._project_psd(lhs, rank=2)
    return lhs, rhs
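
  # A minimal usage sketch (hypothetical names and shapes, not part of the
  # original file): during a column solve for one item, the sufficient
  # statistics could be
  #   lhs = user_embs.T @ user_embs  # [d, d] local gramian
  #   rhs = user_embs.T @ ratings    # [d] right-hand side
  #   noisy_lhs, noisy_rhs = sanitizer.apply_noise((lhs, rhs))
  #   sanitizer.refresh_key()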

  def compute_epsilon(self, target_delta):
    """Computes epsilon."""
    if not all(self.sigmas):
      return np.inf
    # The accounting is done as follows: whenever we compute a statistic with
    # L2 sensitivity k and add Gaussian noise of scale σ, the procedure is
    # (α, αβ/2)-RDP with β = k²/σ². To compose RDP processes, we sum their β.
    # β accounts for running 1 step of ALS, which involves the following
    # computations:
    # - one global Gramian (k² = 1, σ = s0)
    # - local Gramians (k² = budget, σ = s1)
    # - RHSs (k² = budget, σ = s2)
    s0, s1, s2 = self.sigmas
    beta = 1 / (s0**2) + 1 / (s1**2) + 1 / (s2**2)
    # We translate (α, αβ/2)-RDP to (ε, δ)-DP with ε = αβ/2 + log(1/δ)/(α−1).
    # We pick the α that minimizes ε, which is α = 1 + √(2log(1/δ)/β).
    alpha = 1.0 + np.sqrt(np.log(1.0 / target_delta) * 2.0 / beta)
    eps = alpha * beta / 2.0 + np.log(1.0 / target_delta) / (alpha - 1.0)
    return eps
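
  # Illustrative check (not part of the original file): with
  # self.sigmas = [1, 1, 1] and target_delta = 1e-5, beta = 3 and
  # alpha = 1 + sqrt(2 * ln(1e5) / 3) ≈ 3.77, so
  # compute_epsilon(1e-5) ≈ 3.77 * 3 / 2 + ln(1e5) / 2.77 ≈ 9.81.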

  def set_sigmas(self,
                 target_epsilon,
                 target_delta,
                 sigma_ratio0=1,
                 sigma_ratio1=1):
    """Sets sigmas to get the target (epsilon, delta).

    Args:
      target_epsilon: the desired epsilon.
      target_delta: the desired delta.
      sigma_ratio0: the ratio sigma0/sigma2.
      sigma_ratio1: the ratio sigma1/sigma2.
    """
    s_lower = 1e-6
    s_upper = 1e6

    def get_epsilon(s):
      self.sigmas = [sigma_ratio0 * s, sigma_ratio1 * s, s]
      return self.compute_epsilon(target_delta)

    eps = get_epsilon(s_lower)
    i = 0
    while np.abs(eps / target_epsilon - 1) > 0.0001:
      s = (s_lower + s_upper) / 2
      eps = get_epsilon(s)
      if eps > target_epsilon:
        s_lower = s
      else:
        s_upper = s
      i += 1
      if i > 1000:
        raise ValueError(
            "No value of sigmas found for the desired (epsilon, delta)="
            f"{target_epsilon, target_delta}. Consider increasing stddev.")
    s0, s1, s2 = self.sigmas
    print(f"Setting sigmas to [{s0:.2f}, {s1:.2f}, {s2:.2f}], given target "
          f"(epsilon, delta)={target_epsilon, target_delta}")