# Source: google-research repository, trimap module (original file: 507 lines, 15.9 KB).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""TriMap: Large-scale Dimensionality Reduction Using Triplets.
17
18Source: https://arxiv.org/pdf/1910.00204.pdf
19"""
20
21import datetime22import time23from typing import Mapping24
25from absl import logging26import jax27import jax.numpy as jnp28import jax.random as random29import numpy as np30import pynndescent31from sklearn.decomposition import PCA32from sklearn.decomposition import TruncatedSVD33
34_DIM_PCA = 10035_INIT_SCALE = 0.0136_INIT_MOMENTUM = 0.537_FINAL_MOMENTUM = 0.838_SWITCH_ITER = 25039_MIN_GAIN = 0.0140_INCREASE_GAIN = 0.241_DAMP_GAIN = 0.842_DISPLAY_ITER = 10043
44
def tempered_log(x, t):
  """Tempered log with temperature t.

  Args:
    x: Input values.
    t: Temperature; the function reduces to the natural log as t approaches 1.

  Returns:
    The tempered logarithm of x.
  """
  if jnp.abs(t - 1.0) < 1e-5:
    # Near t == 1 the closed form degenerates (division by ~0); use log.
    return jnp.log(x)
  return (jnp.power(x, 1.0 - t) - 1.0) * (1. / (1. - t))
52
def get_distance_fn(distance_fn_name):
  """Look up the row-wise distance function for a given name.

  Args:
    distance_fn_name: One of 'euclidean', 'manhattan', 'cosine', 'hamming',
      or 'chebyshev'.

  Returns:
    The matching distance function.

  Raises:
    ValueError: If the name is not one of the supported distances.
  """
  if distance_fn_name == 'euclidean':
    return euclidean_dist
  if distance_fn_name == 'manhattan':
    return manhattan_dist
  if distance_fn_name == 'cosine':
    return cosine_dist
  if distance_fn_name == 'hamming':
    return hamming_dist
  if distance_fn_name == 'chebyshev':
    return chebyshev_dist
  raise ValueError(f'Distance function {distance_fn_name} not supported.')
68
def sliced_distances(
    indices1,
    indices2,
    inputs,
    distance_fn):
  """Applies distance_fn in smaller slices to avoid memory blow-ups.

  Args:
    indices1: First array of indices.
    indices2: Second array of indices.
    inputs: 2-D array of inputs.
    distance_fn: Distance function that applies row-wise.

  Returns:
    Pairwise distances between the row indices in indices1 and indices2.
  """
  # Process at most inputs.shape[0] index pairs per call to distance_fn.
  chunk = inputs.shape[0]
  pieces = [
      distance_fn(inputs[indices1[lo:lo + chunk]],
                  inputs[indices2[lo:lo + chunk]])
      for lo in range(0, len(indices1), chunk)
  ]
  return jnp.concatenate(pieces)
95
@jax.jit
def squared_euclidean_dist(x1, x2):
  """Squared Euclidean distance between rows of x1 and x2."""
  diff = x1 - x2
  return jnp.sum(diff * diff, axis=-1)
101
@jax.jit
def euclidean_dist(x1, x2):
  """Euclidean (L2) distance between rows of x1 and x2."""
  diff = x1 - x2
  return jnp.sqrt(jnp.sum(diff * diff, axis=-1))
107
@jax.jit
def manhattan_dist(x1, x2):
  """Manhattan (L1) distance between rows of x1 and x2."""
  return jnp.sum(jnp.abs(x2 - x1), axis=-1)
113
@jax.jit
def cosine_dist(x1, x2):
  """Cosine distance (1 - cosine similarity) between rows of x1 and x2."""
  # Clamp norms away from zero so zero vectors do not divide by zero.
  norm1 = jnp.maximum(jnp.linalg.norm(x1, axis=-1), 1e-20)
  norm2 = jnp.maximum(jnp.linalg.norm(x2, axis=-1), 1e-20)
  return 1. - jnp.sum(x1 * x2, -1) / norm1 / norm2
121
@jax.jit
def hamming_dist(x1, x2):
  """Hamming distance (count of differing entries) between rows of x1 and x2."""
  return jnp.sum(jnp.not_equal(x1, x2), axis=-1)
127
@jax.jit
def chebyshev_dist(x1, x2):
  """Chebyshev (L-infinity) distance between rows of x1 and x2."""
  return jnp.max(jnp.abs(x2 - x1), -1)
133
def rejection_sample(key, shape, maxval, rejects):
  """Rejection sample indices.

  Samples integers from the interval [0, maxval) (maxval exclusive, per
  jax.random.randint) while rejecting the values that are in rejects.

  Args:
    key: Random key.
    shape: Output shape.
    maxval: Maximum allowed index value (exclusive).
    rejects: Indices to reject. Leading dimension must match shape[0] so the
      vmapped isin compares row-by-row.

  Returns:
    samples: Sampled indices.
  """
  # Row-wise membership test: row i of the first arg is checked against
  # row i of the second arg.
  in1dvec = jax.vmap(jnp.isin)

  def cond_fun(carry):
    # Keep looping while any sampled entry still collides with its rejects row.
    _, _, discard = carry
    return jnp.any(discard)

  def body_fun(carry):
    key, samples, _ = carry
    key, use_key = random.split(key)
    new_samples = random.randint(use_key, shape=shape, minval=0, maxval=maxval)
    # A fresh draw is discarded if it collides with the previous samples in
    # its row or with the rejects; only collision-free draws replace entries.
    discard = jnp.logical_or(
        in1dvec(new_samples, samples), in1dvec(new_samples, rejects))
    samples = jnp.where(discard, samples, new_samples)
    # NOTE(review): the loop condition only re-checks membership in rejects;
    # duplicates within a row from the initial draw are not re-sampled —
    # confirm this is intended.
    return key, samples, in1dvec(samples, rejects)

  key, use_key = random.split(key)
  samples = random.randint(use_key, shape=shape, minval=0, maxval=maxval)
  discard = in1dvec(samples, rejects)
  _, samples, _ = jax.lax.while_loop(cond_fun, body_fun,
                                     (key, samples, discard))
  return samples
171
def sample_knn_triplets(key, neighbors, n_inliers, n_outliers):
  """Sample nearest neighbors triplets based on the neighbors.

  Args:
    key: Random key.
    neighbors: Nearest neighbors indices for each point (column 0 is the
      point itself).
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.

  Returns:
    triplets: Sampled (anchor, inlier, outlier) index triplets.
  """
  n_points = neighbors.shape[0]
  per_point = n_inliers * n_outliers
  # Each point anchors n_inliers * n_outliers triplets.
  anchors = jnp.repeat(jnp.arange(n_points), per_point).reshape([-1, 1])
  # Skip column 0 (the point itself); repeat each inlier for every outlier.
  inliers = jnp.tile(neighbors[:, 1:n_inliers + 1],
                     [1, n_outliers]).reshape([-1, 1])
  # Outliers are uniform draws that avoid each point's neighbor list.
  outliers = rejection_sample(key, (n_points, per_point), n_points,
                              neighbors).reshape([-1, 1])
  return jnp.concatenate((anchors, inliers, outliers), 1)
195
def sample_random_triplets(key, inputs, n_random, distance_fn, sig):
  """Sample uniformly random triplets.

  Args:
    key: Random key.
    inputs: Input points.
    n_random: Number of random triplets per point.
    distance_fn: Distance function.
    sig: Scaling factor for the distances.

  Returns:
    triplets: Sampled triplets, ordered so column 1 is the closer point.
    weights: p_sim - p_out scores (sign is kept even for flipped pairs;
      the caller later shifts weights to be nonnegative).
  """
  n_points = inputs.shape[0]
  anchors = jnp.repeat(jnp.arange(n_points), n_random).reshape([-1, 1])
  # Two distinct non-anchor points per triplet.
  pairs = rejection_sample(key, (n_points * n_random, 2), n_points, anchors)
  anc = anchors[:, 0]
  sim = pairs[:, 0]
  out = pairs[:, 1]
  sim_score = -(sliced_distances(anc, sim, inputs, distance_fn)**2) / (
      sig[anc] * sig[sim])
  out_score = -(sliced_distances(anc, out, inputs, distance_fn)**2) / (
      sig[anc] * sig[out])
  weights = sim_score - out_score
  # Swap the pair when the sampled "outlier" is actually the closer point.
  flip = sim_score < out_score
  pairs = jnp.where(
      jnp.tile(flip.reshape([-1, 1]), [1, 2]), jnp.fliplr(pairs), pairs)
  triplets = jnp.concatenate((anchors, pairs), 1)
  return triplets, weights
228
def find_scaled_neighbors(inputs, neighbors, distance_fn):
  """Calculates the scaled neighbors and their similarities.

  Args:
    inputs: Input examples.
    neighbors: Nearest neighbors indices per point.
    distance_fn: Distance function.

  Returns:
    Scaled squared distances (sorted ascending per row), the neighbors
    reordered to match, and the per-point scale parameter sig.
  """
  n_points, n_neighbors = neighbors.shape
  row_ids = jnp.repeat(jnp.arange(n_points), n_neighbors)
  col_ids = neighbors.flatten()
  sq_dists = sliced_distances(row_ids, col_ids, inputs, distance_fn)**2
  sq_dists = sq_dists.reshape([n_points, -1])
  # Per-point scale: mean distance to the 4th-6th entries of the neighbor
  # list, clamped away from zero.
  sig = jnp.maximum(jnp.mean(jnp.sqrt(sq_dists[:, 3:6]), axis=1), 1e-10)
  scaled = sq_dists / (sig.reshape([-1, 1]) * sig[neighbors])
  order = jnp.argsort(scaled, 1)
  return (jnp.take_along_axis(scaled, order, 1),
          jnp.take_along_axis(neighbors, order, 1), sig)
253
def find_triplet_weights(inputs,
                         triplets,
                         neighbors,
                         distance_fn,
                         sig,
                         distances=None):
  """Calculates the weights for the sampled nearest neighbors triplets.

  Args:
    inputs: Input points.
    triplets: Nearest neighbor triplets.
    neighbors: Nearest neighbors.
    distance_fn: Distance function.
    sig: Scaling factor for the distances.
    distances: Nearest neighbor distances (assumed already scaled by sig when
      provided).

  Returns:
    weights: Triplet weights (p_sim - p_out).
  """
  n_points, n_inliers = neighbors.shape
  if distances is None:
    anchor_ids = jnp.repeat(jnp.arange(n_points), n_inliers)
    inlier_ids = neighbors.flatten()
    sq = sliced_distances(anchor_ids, inlier_ids, inputs, distance_fn)**2
    p_sim = -sq / (sig[anchor_ids] * sig[inlier_ids])
  else:
    p_sim = -distances.flatten()
  # Each (anchor, inlier) similarity is reused for every outlier drawn for it.
  n_outliers = triplets.shape[0] // (n_points * n_inliers)
  p_sim = jnp.tile(p_sim.reshape([n_points, n_inliers]),
                   [1, n_outliers]).flatten()
  out_sq = sliced_distances(triplets[:, 0], triplets[:, 2], inputs,
                            distance_fn)**2
  p_out = -out_sq / (sig[triplets[:, 0]] * sig[triplets[:, 2]])
  return p_sim - p_out
291
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
  """Generate triplets.

  Builds knn-based triplets (via pynndescent approximate nearest neighbors)
  plus optional uniformly random triplets, and computes their weights.

  Args:
    key: Random key.
    inputs: Input points.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type.
    verbose: Whether to print progress.

  Returns:
    triplets and weights
  """
  n_points = inputs.shape[0]
  # Query extra neighbors so the scaled re-sorting still has enough candidates.
  n_extra = min(n_inliers + 50, n_points)
  index = pynndescent.NNDescent(inputs, metric=distance)
  index.prepare()
  neighbors = index.query(inputs, n_extra)[0]
  # Prepend each point itself as neighbor 0.
  neighbors = np.concatenate((np.arange(n_points).reshape([-1, 1]), neighbors),
                             1)
  if verbose:
    logging.info('found nearest neighbors')
  distance_fn = get_distance_fn(distance)
  # compute scaled neighbors and the scale parameter
  knn_distances, neighbors, sig = find_scaled_neighbors(inputs, neighbors,
                                                        distance_fn)
  neighbors = neighbors[:, :n_inliers + 1]
  knn_distances = knn_distances[:, :n_inliers + 1]
  key, use_key = random.split(key)
  triplets = sample_knn_triplets(use_key, neighbors, n_inliers, n_outliers)
  weights = find_triplet_weights(
      inputs,
      triplets,
      neighbors[:, 1:n_inliers + 1],
      distance_fn,
      sig,
      distances=knn_distances[:, 1:n_inliers + 1])
  # Negative weight means the "outlier" is closer than the "inlier": swap them.
  flip = weights < 0
  anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
  pairs = jnp.where(
      jnp.tile(flip.reshape([-1, 1]), [1, 2]), jnp.fliplr(pairs), pairs)
  triplets = jnp.concatenate((anchors, pairs), 1)

  if n_random > 0:
    key, use_key = random.split(key)
    rand_triplets, rand_weights = sample_random_triplets(
        use_key, inputs, n_random, distance_fn, sig)

    triplets = jnp.concatenate((triplets, rand_triplets), 0)
    # Random triplets get a fixed 0.1 down-weighting relative to knn triplets.
    weights = jnp.concatenate((weights, 0.1 * rand_weights))

  # Shift weights to be nonnegative, then compress with a tempered log.
  weights -= jnp.min(weights)
  weights = tempered_log(1. + weights, weight_temp)
  return triplets, weights
357
@jax.jit
def update_embedding_dbd(embedding, grad, vel, gain, lr, iter_num):
  """Update the embedding using delta-bar-delta (DBD).

  Args:
    embedding: Current embedding array.
    grad: Gradient of the loss w.r.t. the embedding (same shape).
    vel: Per-coordinate velocity (momentum) buffer.
    gain: Per-coordinate adaptive gain buffer.
    lr: Learning rate.
    iter_num: Current iteration; selects the momentum schedule.

  Returns:
    Updated (embedding, vel, gain).
  """
  gamma = jnp.where(iter_num > _SWITCH_ITER, _FINAL_MOMENTUM, _INIT_MOMENTUM)
  # Grow the gain where the gradient opposes the current velocity, otherwise
  # damp it toward _MIN_GAIN (classic delta-bar-delta, as used in t-SNE).
  gain = jnp.where(
      jnp.sign(vel) != jnp.sign(grad), gain + _INCREASE_GAIN,
      jnp.maximum(gain * _DAMP_GAIN, _MIN_GAIN))
  vel = gamma * vel - lr * gain * grad
  embedding += vel
  # BUGFIX: return order was (embedding, gain, vel), but the caller in
  # transform() unpacks (embedding, vel, gain) — the swap corrupted both
  # optimizer buffers every iteration. Return them in the caller's order.
  return embedding, vel, gain
369
@jax.jit
def trimap_metrics(embedding, triplets, weights):
  """Return trimap loss and number of violated triplets."""
  anchors = embedding[triplets[:, 0]]
  inliers = embedding[triplets[:, 1]]
  outliers = embedding[triplets[:, 2]]
  sim_distance = 1. + squared_euclidean_dist(anchors, inliers)
  out_distance = 1. + squared_euclidean_dist(anchors, outliers)
  # A triplet is violated when the inlier sits farther than the outlier.
  num_violated = jnp.sum(sim_distance > out_distance)
  loss = jnp.mean(weights * 1. / (1. + out_distance / sim_distance))
  return loss, num_violated
382
@jax.jit
def trimap_loss(embedding, triplets, weights):
  """Return the trimap loss (first element of trimap_metrics)."""
  return trimap_metrics(embedding, triplets, weights)[0]
389
def transform(key,
              inputs,
              n_dims=2,
              n_inliers=10,
              n_outliers=5,
              n_random=3,
              weight_temp=0.5,
              distance='euclidean',
              lr=0.1,
              n_iters=400,
              init_embedding='pca',
              apply_pca=True,
              triplets=None,
              weights=None,
              verbose=False):
  """Transform inputs using TriMap.

  Args:
    key: Random key.
    inputs: Input points.
    n_dims: Number of output dimension.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type.
    lr: Learning rate.
    n_iters: Number of iterations.
    init_embedding: Initial embedding: pca, random, or pass pre-computed.
    apply_pca: Apply PCA to reduce the dimension for knn search.
    triplets: Use pre-sampled triplets.
    weights: Use pre-computed weights.
    verbose: Whether to print progress.

  Returns:
    embedding
  """

  if verbose:
    # t is only defined (and later read) on the verbose path.
    t = time.time()
  n_points, dim = inputs.shape
  assert n_inliers < n_points - 1, (
      'n_inliers must be less than (number of data points - 1).')
  if verbose:
    logging.info('running TriMap on %d points with dimension %d', n_points, dim)
  pca_solution = False
  if triplets is None:
    if verbose:
      logging.info('pre-processing')
    if distance != 'hamming':
      if dim > _DIM_PCA and apply_pca:
        # High-dimensional inputs: center, then reduce to _DIM_PCA via SVD.
        inputs -= np.mean(inputs, axis=0)
        inputs = TruncatedSVD(
            n_components=_DIM_PCA, random_state=0).fit_transform(inputs)
        pca_solution = True
        if verbose:
          logging.info('applied PCA')
      else:
        # Otherwise min/max-normalize and center the inputs in place.
        inputs -= np.min(inputs)
        inputs /= np.max(inputs)
        inputs -= np.mean(inputs, axis=0)
    key, use_key = random.split(key)
    # NOTE(review): use_key is unused here; generate_triplets receives the
    # post-split key instead — confirm whether use_key was intended.
    triplets, weights = generate_triplets(
        key,
        inputs,
        n_inliers,
        n_outliers,
        n_random,
        weight_temp=weight_temp,
        distance=distance,
        verbose=verbose)
    if verbose:
      logging.info('sampled triplets')
  else:
    if verbose:
      logging.info('using pre-computed triplets')

  if isinstance(init_embedding, str):
    if init_embedding == 'pca':
      if pca_solution:
        # Reuse the truncated-SVD projection as the PCA initialization.
        embedding = jnp.array(_INIT_SCALE * inputs[:, :n_dims])
      else:
        embedding = jnp.array(
            _INIT_SCALE *
            PCA(n_components=n_dims).fit_transform(inputs).astype(np.float32))
    elif init_embedding == 'random':
      key, use_key = random.split(key)
      embedding = random.normal(
          use_key, shape=[n_points, n_dims], dtype=jnp.float32) * _INIT_SCALE
  else:
    # Non-string init_embedding is treated as a pre-computed embedding.
    embedding = jnp.array(init_embedding, dtype=jnp.float32)

  n_triplets = float(triplets.shape[0])
  # Scale the learning rate by points-per-triplet.
  lr = lr * n_points / n_triplets
  if verbose:
    logging.info('running TriMap using DBD')
  vel = jnp.zeros_like(embedding, dtype=jnp.float32)
  gain = jnp.ones_like(embedding, dtype=jnp.float32)

  trimap_grad = jax.jit(jax.grad(trimap_loss))

  for itr in range(n_iters):
    gamma = _FINAL_MOMENTUM if itr > _SWITCH_ITER else _INIT_MOMENTUM
    # Nesterov-style lookahead: evaluate the gradient at embedding + gamma*vel.
    grad = trimap_grad(embedding + gamma * vel, triplets, weights)

    # update the embedding
    # NOTE(review): this unpacking must match update_embedding_dbd's return
    # order (embedding, vel, gain) — verify the two agree.
    embedding, vel, gain = update_embedding_dbd(embedding, grad, vel, gain, lr,
                                                itr)
    if verbose:
      if (itr + 1) % _DISPLAY_ITER == 0:
        loss, n_violated = trimap_metrics(embedding, triplets, weights)
        logging.info(
            'Iteration: %4d / %4d, Loss: %3.3f, Violated triplets: %0.4f',
            itr + 1, n_iters, loss, n_violated / n_triplets * 100.0)
  if verbose:
    elapsed = str(datetime.timedelta(seconds=time.time() - t))
    logging.info('Elapsed time: %s', elapsed)
  return embedding