google-research
181 lines · 5.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Newton step DP sanitizer."""
17
18import functools19
20import jax21import jax.numpy as jnp22import numpy as np23
24
25def clip(x, clip_norm=1.0):26divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)27return x / divisor28
29
30def noisy(x, s, key):31if 0 < s < np.inf:32key, subkey = jax.random.split(key)33noise = jax.random.normal(subkey, shape=jnp.shape(x)) * s34return x + noise35return x36
37
class Sanitizer(object):
  """Provides a set of functions for sanitizing private information.

  There are utilities used during pre-processing of the data, and others used
  during training. Sanitizer also provides a compute_epsilon function to
  compute the privacy loss for its sanitization process.

  Training:
    clip: clips the norm of each user embedding. Used during row solves.
    apply_noise: adds noise to a list of the input sufficient statistics
      involving the user data. Used during column solves.
  """

  def __init__(self,
               steps,
               max_norm=1,
               num_classes=1000,
               s1=0,
               s2=0,
               random_seed=None):
    """Initializes a Sanitizer.

    Args:
      steps: number of optimization steps.
      max_norm: clips the user embeddings to max_norm. A falsy value (0 or
        None) disables both clipping and noise.
      num_classes: number of classifier classes.
      s1: the noise factor for local gramian.
      s2: the noise factor for rhs.
      random_seed: rng seed for sampling. Defaults to 42 when not given.
    """
    # `is not None` (rather than truthiness) so an explicit seed of 0 is
    # honored instead of silently falling back to the default seed.
    if random_seed is not None:
      self.key = jax.random.PRNGKey(random_seed)
    else:
      self.key = jax.random.PRNGKey(42)
    self.max_norm = max_norm
    # sigmas = [s1 (local gramian noise), s2 (rhs noise)].
    self.sigmas = [s1, s2]
    self.steps = steps
    self.num_classes = num_classes

  def refresh_key(self):
    """Use PRNG key only once, this function refreshes it once its used."""
    _, self.key = jax.random.split(self.key)

  def clip(self, embeddings):
    """Clips each row of `embeddings` to L2 norm at most `self.max_norm`.

    Args:
      embeddings: a batch of embeddings; clipping is vmapped over axis 0.

    Returns:
      The clipped embeddings, or `embeddings` unchanged when `max_norm` is
      falsy (clipping disabled).
    """
    if not self.max_norm:
      return embeddings
    return jax.vmap(functools.partial(clip, clip_norm=self.max_norm))(
        embeddings)

  def _project_psd(self, x, rank):
    """Project a rank 2 or 3 tensor to PSD.

    Args:
      x: a square matrix (rank 2) or a batch of square matrices (rank 3).
      rank: 2 or 3; selects which axes are transposed.

    Returns:
      The positive-semidefinite projection of the symmetrized input.

    Raises:
      ValueError: if `rank` is neither 2 nor 3.
    """
    if rank == 2:
      indices = [1, 0]
    elif rank == 3:
      # Batched case: transpose the trailing two (matrix) axes only.
      indices = [0, 2, 1]
    else:
      raise ValueError("rank must be 2 or 3")

    def transpose(x):
      return jnp.transpose(x, indices)

    # Symmetrize first, then clamp negative eigenvalues to zero and
    # reconstruct: V diag(max(e, 0)) V^T.
    x = (x + transpose(x)) / 2
    e, v = jnp.linalg.eigh(x)
    e = jnp.maximum(e, 0)
    return v @ (jnp.expand_dims(e, -1) * transpose(v))

  def apply_noise(self, stats):
    """Apply noise to stats.

    Args:
      stats: a pair (local_gramian, rhs) of sufficient statistics.

    Returns:
      A pair (noised local gramian projected to PSD, noised rhs). When
      `max_norm` is falsy, `stats` is returned unchanged.

    Raises:
      ValueError: if `stats` is not a 2-element list or tuple.
    """
    # Bug fix: the original used `and`, so the check could never fire — a
    # non-sequence or a wrong-length sequence both slipped through.
    if not isinstance(stats, (list, tuple)) or len(stats) != 2:
      raise ValueError("stats must be a pair of (local_gramian, rhs).")
    if not self.max_norm:
      return stats
    sigmas = self.sigmas
    max_norm = self.max_norm
    num_classes_scale = np.sqrt(self.num_classes)
    num_classes_scale_hessian = num_classes_scale
    num_classes_scale_gradient = num_classes_scale
    # Scale the configured noise factors by each statistic's sensitivity:
    # max_norm**2 / 4 for the gramian term, max_norm for the rhs term
    # (matching the accounting comment in compute_epsilon).
    sigmas = [
        num_classes_scale_hessian * sigmas[0] * (max_norm**2) / 4.0,
        num_classes_scale_gradient * sigmas[1] * max_norm
    ]
    # Get fresh keys.
    keys = jax.random.split(self.key, num=2)

    lhs, rhs = [
        noisy(x, s, key=key)
        for x, s, key in zip(stats, sigmas, keys)
    ]
    # Noise can break positive-semidefiniteness; project back.
    lhs = self._project_psd(lhs, rank=2)
    return lhs, rhs

  def compute_epsilon(self, target_delta):
    """Computes epsilon for the given `target_delta`.

    Returns np.inf when any noise factor is zero (no privacy guarantee).
    """
    if not all(self.sigmas):
      return np.inf
    # The accounting is done as follows: whenever we compute a statistic with
    # L2 sensitivity k and add Gaussian noise of scale σ, the procedure is
    # (α, αβ/2)-RDP with β = k²/σ². To compose RDP processes, we sum their β.
    s1, s2 = self.sigmas
    s1_multiplier = self.steps
    s2_multiplier = self.steps
    beta = (s1_multiplier / (s1**2) + s2_multiplier / (s2**2))
    # We translate (α, αβ/2)-RDP to (ε, δ)-DP with ε = αβ/2 + log(1/δ)/(α−1).
    # We pick the α that minimizes ε, which is α = 1 + √(2log(1/δ)/β)
    alpha = 1.0 + np.sqrt(np.log(1.0 / target_delta) * 2.0 / beta)
    eps = alpha * beta / 2.0 + np.log(1.0 / target_delta) / (alpha - 1.0)
    return eps

  def set_sigmas(self,
                 target_epsilon,
                 target_delta,
                 sigma_ratio1=1):
    """Sets sigmas to get the target (epsilon, delta).

    Runs a bisection on the common noise scale until the resulting epsilon
    is within 0.01% of `target_epsilon`.

    Args:
      target_epsilon: the desired epsilon.
      target_delta: the desired delta.
      sigma_ratio1: the ratio sigma1/sigma2.

    Raises:
      ValueError: if no suitable sigmas are found within the search bounds.
    """
    s_lower = 1e-6
    s_upper = 1e6

    def get_epsilon(s):
      # Side effect: sets self.sigmas, so the last call fixes the result.
      self.sigmas = [sigma_ratio1 * s, s]
      return self.compute_epsilon(target_delta)

    eps = get_epsilon(s_lower)
    i = 0
    while np.abs(eps / target_epsilon - 1) > 0.0001:
      s = (s_lower + s_upper) / 2
      eps = get_epsilon(s)
      # Epsilon decreases as the noise scale grows, so move the bound that
      # keeps the target bracketed.
      if eps > target_epsilon:
        s_lower = s
      else:
        s_upper = s
      i += 1
      if i > 1000:
        raise ValueError(
            f"No value of sigmas found for the desired (epsilon, delta)="
            f"={target_epsilon, target_delta}. Consider increasing stddev.")
    s1, s2 = self.sigmas
    print(
        f"Setting sigmas to [{s1:.2f}, {s2:.2f}], given target "
        f"(epsilon, delta)={target_epsilon, target_delta}")