google-research

Форк
0
507 строк · 15.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""TriMap: Large-scale Dimensionality Reduction Using Triplets.
17

18
Source: https://arxiv.org/pdf/1910.00204.pdf
19
"""
20

21
import datetime
22
import time
23
from typing import Mapping
24

25
from absl import logging
26
import jax
27
import jax.numpy as jnp
28
import jax.random as random
29
import numpy as np
30
import pynndescent
31
from sklearn.decomposition import PCA
32
from sklearn.decomposition import TruncatedSVD
33

34
# Pre-processing: when apply_pca is set (and the distance is not hamming),
# inputs with more than _DIM_PCA features are reduced to _DIM_PCA dimensions
# with TruncatedSVD before the nearest-neighbor search.
_DIM_PCA = 100
# Multiplier applied to the initial (PCA or random) embedding.
_INIT_SCALE = 0.01

# Momentum schedule used by delta-bar-delta: _INIT_MOMENTUM while the
# (0-based) iteration index is <= _SWITCH_ITER, _FINAL_MOMENTUM afterwards.
_INIT_MOMENTUM = 0.5
_FINAL_MOMENTUM = 0.8
_SWITCH_ITER = 250

# Per-coordinate gain adaptation: a gain grows by _INCREASE_GAIN when the
# signs of velocity and gradient differ, otherwise it is multiplied by
# _DAMP_GAIN but never drops below _MIN_GAIN.
_MIN_GAIN = 0.01
_INCREASE_GAIN = 0.2
_DAMP_GAIN = 0.8

# With verbose=True, progress is logged every _DISPLAY_ITER iterations.
_DISPLAY_ITER = 100
43

44

45
def tempered_log(x, t):
  """Computes the tempered logarithm log_t(x) = (x**(1 - t) - 1) / (1 - t).

  The expression tends to the natural logarithm as t -> 1, so when t is
  within 1e-5 of 1 the natural logarithm is returned directly (avoiding the
  division by (1 - t) ~ 0).

  Args:
    x: Input value(s).
    t: Temperature (scalar).

  Returns:
    The tempered logarithm of x, element-wise.
  """
  if jnp.abs(t - 1.0) < 1e-5:
    return jnp.log(x)
  exponent = 1. - t
  return 1. / exponent * (jnp.power(x, exponent) - 1.0)
51

52

53
def get_distance_fn(distance_fn_name):
  """Returns the row-wise distance function registered under a metric name.

  Args:
    distance_fn_name: One of 'euclidean', 'manhattan', 'cosine', 'hamming'
      or 'chebyshev'.

  Returns:
    A function mapping two arrays of rows to their row-wise distances.

  Raises:
    ValueError: If the name is not one of the supported metrics.
  """
  supported = (
      ('euclidean', euclidean_dist),
      ('manhattan', manhattan_dist),
      ('cosine', cosine_dist),
      ('hamming', hamming_dist),
      ('chebyshev', chebyshev_dist),
  )
  for name, fn in supported:
    if distance_fn_name == name:
      return fn
  raise ValueError(f'Distance function {distance_fn_name} not supported.')
67

68

69
def sliced_distances(
    indices1,
    indices2,
    inputs,
    distance_fn,
    slice_size=None):
  """Applies distance_fn in smaller slices to avoid memory blow-ups.

  At most `slice_size` pairs of rows are gathered from `inputs` at a time, so
  the temporary gathered arrays stay bounded regardless of how many index
  pairs are requested.

  Args:
    indices1: First array of row indices into `inputs`.
    indices2: Second array of row indices, paired element-wise with
      `indices1` (assumed to have the same length).
    inputs: 2-D array of inputs.
    distance_fn: Distance function that applies row-wise.
    slice_size: Number of index pairs evaluated per slice. Defaults to the
      number of rows of `inputs` (the historical behavior).

  Returns:
    1-D array whose i-th entry is
    distance_fn(inputs[indices1[i]], inputs[indices2[i]]).

  Raises:
    ValueError: If `slice_size` is not positive while there are pairs to
      evaluate.
  """
  n_pairs = len(indices1)
  if n_pairs == 0:
    # jnp.concatenate rejects an empty list, so evaluate the empty gather
    # directly; this yields a correctly shaped, empty result.
    return jnp.asarray(distance_fn(inputs[indices1], inputs[indices2]))
  if slice_size is None:
    slice_size = inputs.shape[0]
  if slice_size <= 0:
    raise ValueError(f'slice_size must be positive, got {slice_size}.')
  distances = [
      distance_fn(inputs[indices1[start:start + slice_size]],
                  inputs[indices2[start:start + slice_size]])
      for start in range(0, n_pairs, slice_size)
  ]
  return jnp.concatenate(distances)
94

95

96
@jax.jit
def squared_euclidean_dist(x1, x2):
  """Row-wise squared Euclidean distance, summed over the last axis."""
  delta = x1 - x2
  return jnp.square(delta).sum(axis=-1)
100

101

102
@jax.jit
def euclidean_dist(x1, x2):
  """Row-wise Euclidean (L2) distance, reduced over the last axis."""
  delta = x1 - x2
  return jnp.sqrt(jnp.square(delta).sum(axis=-1))
106

107

108
@jax.jit
def manhattan_dist(x1, x2):
  """Row-wise Manhattan (L1) distance, reduced over the last axis."""
  abs_diff = jnp.abs(x1 - x2)
  return abs_diff.sum(axis=-1)
112

113

114
@jax.jit
def cosine_dist(x1, x2):
  """Row-wise cosine distance, i.e. 1 - cosine similarity.

  Row norms are clamped from below at 1e-20, so an all-zero row does not
  divide by zero: its similarity evaluates to 0 and its distance to 1.
  """
  eps = 1e-20
  dot = jnp.sum(x1 * x2, -1)
  norm1 = jnp.maximum(jnp.linalg.norm(x1, axis=-1), eps)
  norm2 = jnp.maximum(jnp.linalg.norm(x2, axis=-1), eps)
  return 1. - dot / norm1 / norm2
120

121

122
@jax.jit
def hamming_dist(x1, x2):
  """Number of positions (last axis) at which rows of x1 and x2 differ.

  The result is a raw count; it is not normalized by the row length.
  """
  mismatches = jnp.not_equal(x1, x2)
  return mismatches.sum(axis=-1)
126

127

128
@jax.jit
def chebyshev_dist(x1, x2):
  """Row-wise Chebyshev (L-infinity) distance: largest absolute coordinate gap."""
  return jnp.abs(x1 - x2).max(axis=-1)
132

133

134
def rejection_sample(key, shape, maxval, rejects):
  """Rejection-samples integer indices from the half-open interval [0, maxval).

  Entries that fall into the same row of `rejects` are redrawn until no
  entry of the result appears in its row of `rejects`. Membership is tested
  row by row (jax.vmap over jnp.isin maps over the leading axis), so
  `rejects` must have the same number of rows as `shape[0]`.

  On each redraw, a candidate replaces the current entry only if it is in
  neither the row's current samples nor the row's rejects. The loop stops
  as soon as no entry is rejected; duplicates present in the very first draw
  are not removed. NOTE(review): the loop cannot terminate if some row's
  rejects cover every value in [0, maxval).

  Args:
    key: Random key.
    shape: Output shape.
    maxval: Exclusive upper bound of the sampled values.
    rejects: Indices to reject, one row per output row.

  Returns:
    samples: Sampled indices with the given shape.
  """
  row_isin = jax.vmap(jnp.isin)

  def draw(rng):
    # Split off a fresh subkey and draw a full batch of candidates.
    rng, subkey = random.split(rng)
    return rng, random.randint(subkey, shape=shape, minval=0, maxval=maxval)

  def still_rejected(state):
    return jnp.any(state[2])

  def resample(state):
    rng, current, _ = state
    rng, candidates = draw(rng)
    keep_current = (row_isin(candidates, current) |
                    row_isin(candidates, rejects))
    current = jnp.where(keep_current, current, candidates)
    return rng, current, row_isin(current, rejects)

  rng, initial = draw(key)
  state = (rng, initial, row_isin(initial, rejects))
  return jax.lax.while_loop(still_rejected, resample, state)[1]
170

171

172
def sample_knn_triplets(key, neighbors, n_inliers, n_outliers):
  """Builds (anchor, inlier, outlier) triplets from nearest-neighbor lists.

  Every point is used as an anchor. Each of its first `n_inliers` neighbors
  after column 0 (columns 1..n_inliers) is paired with `n_outliers` outliers.
  The outliers are drawn uniformly from all point indices, excluding the
  indices listed in that anchor's row of `neighbors`.

  Args:
    key: Random key.
    neighbors: Nearest neighbors indices for each point, one row per point.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.

  Returns:
    triplets: Integer array with columns (anchor, inlier, outlier) and
      n_points * n_inliers * n_outliers rows.
  """
  n_points = neighbors.shape[0]
  per_anchor = n_inliers * n_outliers
  # Anchor i repeated once for every (inlier, outlier) combination.
  anchor_col = jnp.repeat(jnp.arange(n_points), per_anchor).reshape([-1, 1])
  # Inlier block repeated n_outliers times along each row.
  inlier_col = jnp.tile(neighbors[:, 1:n_inliers + 1],
                        [1, n_outliers]).reshape([-1, 1])
  outlier_col = rejection_sample(key, (n_points, per_anchor), n_points,
                                 neighbors).reshape([-1, 1])
  return jnp.concatenate((anchor_col, inlier_col, outlier_col), 1)
194

195

196
def sample_random_triplets(key, inputs, n_random, distance_fn, sig):
  """Samples uniformly random triplets and weights them.

  For each anchor, two other points are drawn at random; the draw rejects
  only the anchor itself. Each point gets a scaled log-similarity
  p = -d(anchor, point)**2 / (sig[anchor] * sig[point]).

  The pair is then ordered so the more similar point comes first. The weight
  is the non-negative gap |p_sim - p_out|. Using the absolute gap makes the
  weight of a triplet independent of the random order in which its two
  points happened to be drawn. Previously a flipped pair kept a negative
  weight, so the same triplet was weighted +gap or -gap depending on the draw.

  Args:
    key: Random key.
    inputs: Input points.
    n_random: Number of random triplets per point.
    distance_fn: Distance function.
    sig: Scaling factor for the distances.

  Returns:
    triplets: Array with columns (anchor, more similar, less similar).
    weights: Non-negative weight per triplet.
  """
  n_points = inputs.shape[0]
  anchors = jnp.tile(jnp.arange(n_points).reshape([-1, 1]),
                     [1, n_random]).reshape([-1, 1])
  pairs = rejection_sample(key, (n_points * n_random, 2), n_points, anchors)
  anc = anchors[:, 0]
  sim = pairs[:, 0]
  out = pairs[:, 1]
  p_sim = -(sliced_distances(anc, sim, inputs, distance_fn)**2) / (
      sig[anc] * sig[sim])
  p_out = -(sliced_distances(anc, out, inputs, distance_fn)**2) / (
      sig[anc] * sig[out])
  # Put the more similar point in the "similar" slot.
  flip = p_sim < p_out
  pairs = jnp.where(flip[:, None], jnp.fliplr(pairs), pairs)
  # Gap between the (re-ordered) similarities; equals p_out - p_sim for
  # flipped pairs, so it is always >= 0.
  weights = jnp.abs(p_sim - p_out)
  triplets = jnp.concatenate((anchors, pairs), 1)
  return triplets, weights
227

228

229
def find_scaled_neighbors(inputs, neighbors, distance_fn):
  """Rescales neighbor distances per point and re-sorts the neighbor lists.

  For each point i, the scale sig[i] is the mean distance to the neighbors
  in columns 3..5 of `neighbors`, taken in the order given (before
  re-sorting). It is floored at 1e-10. Each squared distance d(i, j)**2 is
  divided by sig[i] * sig[j], and every row is re-sorted by that scaled
  value.

  Args:
    inputs: Input examples.
    neighbors: Nearest neighbors, one row of indices per point.
    distance_fn: Distance function.

  Returns:
    scaled_distances: Sorted scaled squared distances, one row per point.
    sorted_neighbors: `neighbors` re-ordered to match scaled_distances.
    sig: Per-point scale parameter.
  """
  n_points, n_neighbors = neighbors.shape
  row_ids = jnp.repeat(jnp.arange(n_points), n_neighbors)
  flat_sq = sliced_distances(row_ids, neighbors.flatten(), inputs,
                             distance_fn)**2
  sq_dist = flat_sq.reshape([n_points, -1])
  sig = jnp.maximum(jnp.sqrt(sq_dist[:, 3:6]).mean(axis=1), 1e-10)
  scaled = sq_dist / (sig[:, None] * sig[neighbors])
  order = jnp.argsort(scaled, axis=1)
  return (jnp.take_along_axis(scaled, order, 1),
          jnp.take_along_axis(neighbors, order, 1),
          sig)
252

253

254
def find_triplet_weights(inputs,
                         triplets,
                         neighbors,
                         distance_fn,
                         sig,
                         distances=None):
  """Weights nearest-neighbor triplets by their scaled similarity gap.

  The inlier similarity is p_sim = -(scaled squared distance). If
  `distances` is given, it is used as those scaled squared distances
  directly. Otherwise they are computed as d**2 / (sig[i] * sig[j]). The
  outlier similarity p_out is computed the same way from triplets[:, 0] and
  triplets[:, 2].

  The triplets are assumed to be laid out as produced by
  sample_knn_triplets: anchor-major, with each anchor's inliers repeated
  n_outliers times. The weight is p_sim - p_out and may be negative.

  Args:
    inputs: Input points.
    triplets: Nearest neighbor triplets.
    neighbors: Nearest neighbors used as inliers (n_points x n_inliers).
    distance_fn: Distance function.
    sig: Scaling factor for the distances.
    distances: Optional pre-computed scaled inlier distances.

  Returns:
    weights: Triplet weights.
  """
  n_points, n_inliers = neighbors.shape
  if distances is not None:
    p_sim = -distances.flatten()
  else:
    anchor_ids = jnp.repeat(jnp.arange(n_points), n_inliers)
    inlier_ids = neighbors.flatten()
    inlier_sq = sliced_distances(anchor_ids, inlier_ids, inputs,
                                 distance_fn)**2
    p_sim = -inlier_sq / (sig[anchor_ids] * sig[inlier_ids])
  n_outliers = triplets.shape[0] // (n_points * n_inliers)
  # Repeat each anchor's inlier similarities once per sampled outlier.
  p_sim = jnp.tile(p_sim.reshape([n_points, n_inliers]),
                   [1, n_outliers]).flatten()
  anchor_col = triplets[:, 0]
  outlier_col = triplets[:, 2]
  outlier_sq = sliced_distances(anchor_col, outlier_col, inputs,
                                distance_fn)**2
  p_out = -outlier_sq / (sig[anchor_col] * sig[outlier_col])
  return p_sim - p_out
290

291

292
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
  """Samples nearest-neighbor and random triplets and computes their weights.

  Args:
    key: Random key.
    inputs: Input points.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type (also passed to pynndescent as the metric).
    verbose: Whether to print progress.

  Returns:
    triplets: Array with columns (anchor, similar, dissimilar).
    weights: Tempered-log transformed, shifted weights (one per triplet).
  """
  n_points = inputs.shape[0]
  n_extra = min(n_inliers + 50, n_points)
  index = pynndescent.NNDescent(inputs, metric=distance)
  index.prepare()
  candidates = index.query(inputs, n_extra)[0]
  self_ids = np.arange(n_points).reshape([-1, 1])
  # Column 0 of `neighbors` is reserved for the point itself; downstream code
  # treats columns 1.. as neighbors (inliers = neighbors[:, 1:]). Querying
  # the index with its own data can also return the point among its
  # candidates (NOTE(review): pynndescent normally does so at distance 0).
  # That would duplicate the point, make it its own first inlier and shift
  # the columns used for the scale parameter. Move any self-match to the end
  # of its row with a stable sort, so the other candidates keep their order,
  # and keep n_extra - 1 of them. Every row then has the same width whether
  # or not the point was returned.
  is_self = candidates == self_ids
  order = np.argsort(is_self, axis=1, kind='stable')
  candidates = np.take_along_axis(candidates, order, axis=1)[:, :n_extra - 1]
  neighbors = np.concatenate((self_ids, candidates), 1)
  if verbose:
    logging.info('found nearest neighbors')
  distance_fn = get_distance_fn(distance)
  # Compute the scaled neighbors and the per-point scale parameter.
  knn_distances, neighbors, sig = find_scaled_neighbors(inputs, neighbors,
                                                        distance_fn)
  neighbors = neighbors[:, :n_inliers + 1]
  knn_distances = knn_distances[:, :n_inliers + 1]
  key, use_key = random.split(key)
  triplets = sample_knn_triplets(use_key, neighbors, n_inliers, n_outliers)
  weights = find_triplet_weights(
      inputs,
      triplets,
      neighbors[:, 1:n_inliers + 1],
      distance_fn,
      sig,
      distances=knn_distances[:, 1:n_inliers + 1])
  # Swap inlier/outlier where the sampled outlier is actually more similar.
  # NOTE(review): the weight of a flipped triplet keeps its negative sign,
  # so after the shift below it receives one of the smallest weights.
  flip = weights < 0
  anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
  pairs = jnp.where(flip[:, None], jnp.fliplr(pairs), pairs)
  triplets = jnp.concatenate((anchors, pairs), 1)

  if n_random > 0:
    key, use_key = random.split(key)
    rand_triplets, rand_weights = sample_random_triplets(
        use_key, inputs, n_random, distance_fn, sig)

    triplets = jnp.concatenate((triplets, rand_triplets), 0)
    # Random triplets are down-weighted relative to nearest-neighbor ones.
    weights = jnp.concatenate((weights, 0.1 * rand_weights))

  # Shift so the smallest weight is 0, then compress with a tempered log.
  weights -= jnp.min(weights)
  weights = tempered_log(1. + weights, weight_temp)
  return triplets, weights
356

357

358
@jax.jit
def update_embedding_dbd(embedding, grad, vel, gain, lr, iter_num):
  """Performs one delta-bar-delta step with momentum.

  Gains whose velocity and gradient signs differ grow by _INCREASE_GAIN.
  All other gains are damped by _DAMP_GAIN and floored at _MIN_GAIN. The
  velocity then becomes momentum * vel - lr * gain * grad, where momentum
  switches from _INIT_MOMENTUM to _FINAL_MOMENTUM once iter_num exceeds
  _SWITCH_ITER.

  Returns:
    A tuple (embedding, gain, vel), in that order.
  """
  momentum = jnp.where(iter_num > _SWITCH_ITER, _FINAL_MOMENTUM,
                       _INIT_MOMENTUM)
  signs_differ = jnp.sign(vel) != jnp.sign(grad)
  damped = jnp.maximum(gain * _DAMP_GAIN, _MIN_GAIN)
  new_gain = jnp.where(signs_differ, gain + _INCREASE_GAIN, damped)
  new_vel = momentum * vel - lr * new_gain * grad
  return embedding + new_vel, new_gain, new_vel
368

369

370
@jax.jit
def trimap_metrics(embedding, triplets, weights):
  """Computes the TriMap loss and the count of violated triplets.

  With d_sim = 1 + |e_anchor - e_sim|^2 and d_out = 1 + |e_anchor - e_out|^2,
  the loss is mean(weights / (1 + d_out / d_sim)). A triplet counts as
  violated when d_sim > d_out.

  Returns:
    A tuple (loss, num_violated).
  """
  anchors = embedding[triplets[:, 0]]
  d_sim = 1. + squared_euclidean_dist(anchors, embedding[triplets[:, 1]])
  d_out = 1. + squared_euclidean_dist(anchors, embedding[triplets[:, 2]])
  violated = (d_sim > d_out).sum()
  per_triplet = weights * 1. / (1. + d_out / d_sim)
  return per_triplet.mean(), violated
381

382

383
@jax.jit
def trimap_loss(embedding, triplets, weights):
  """Returns only the scalar TriMap loss (the first output of trimap_metrics)."""
  return trimap_metrics(embedding, triplets, weights)[0]
388

389

390
def _preprocess_inputs(inputs, distance, apply_pca, verbose):
  """Prepares inputs for the nearest-neighbor search.

  Hamming inputs are returned untouched. For other distances:
  - High-dimensional inputs (more than _DIM_PCA features, with apply_pca
    set) are centered and reduced with TruncatedSVD.
  - All other inputs are shifted to a minimum of 0, scaled by their maximum,
    and centered.
  Arithmetic is done out of place, so the caller's array is never mutated
  and integer inputs are promoted to floating point.

  Returns:
    A tuple (inputs, pca_solution), where pca_solution tells whether
    TruncatedSVD was applied.
  """
  if distance == 'hamming':
    return inputs, False
  if inputs.shape[1] > _DIM_PCA and apply_pca:
    inputs = inputs - np.mean(inputs, axis=0)
    inputs = TruncatedSVD(
        n_components=_DIM_PCA, random_state=0).fit_transform(inputs)
    if verbose:
      logging.info('applied PCA')
    return inputs, True
  inputs = inputs - np.min(inputs)
  scale = np.max(inputs)
  # A constant input has scale 0; skip the division rather than produce NaNs.
  if scale > 0:
    inputs = inputs / scale
  inputs = inputs - np.mean(inputs, axis=0)
  return inputs, False


def _initial_embedding(key, inputs, init_embedding, n_points, n_dims,
                       pca_solution):
  """Builds the starting embedding: 'pca', 'random' or a given array.

  Raises:
    ValueError: If init_embedding is a string other than 'pca' or 'random'.
  """
  if not isinstance(init_embedding, str):
    return jnp.array(init_embedding, dtype=jnp.float32)
  if init_embedding == 'pca':
    if pca_solution:
      # The inputs already hold TruncatedSVD components; reuse the leading
      # ones.
      return jnp.array(_INIT_SCALE * inputs[:, :n_dims])
    return jnp.array(
        _INIT_SCALE *
        PCA(n_components=n_dims).fit_transform(inputs).astype(np.float32))
  if init_embedding == 'random':
    return random.normal(
        key, shape=[n_points, n_dims], dtype=jnp.float32) * _INIT_SCALE
  raise ValueError(
      f"Unknown init_embedding '{init_embedding}'; expected 'pca', 'random' "
      'or an array.')


def transform(key,
              inputs,
              n_dims=2,
              n_inliers=10,
              n_outliers=5,
              n_random=3,
              weight_temp=0.5,
              distance='euclidean',
              lr=0.1,
              n_iters=400,
              init_embedding='pca',
              apply_pca=True,
              triplets=None,
              weights=None,
              verbose=False):
  """Transform inputs using TriMap.

  Args:
    key: Random key.
    inputs: Input points.
    n_dims: Number of output dimension.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type.
    lr: Learning rate.
    n_iters: Number of iterations.
    init_embedding: Initial embedding: pca, random, or pass pre-computed.
    apply_pca: Apply PCA to reduce the dimension for knn search.
    triplets: Use pre-sampled triplets.
    weights: Use pre-computed weights (required when triplets are given).
    verbose: Whether to print progress.

  Returns:
    embedding

  Raises:
    ValueError: If triplets are given without weights, or init_embedding is
      an unknown string.
  """
  start_time = time.time()
  n_points, dim = inputs.shape
  assert n_inliers < n_points - 1, (
      'n_inliers must be less than (number of data points - 1).')
  if verbose:
    logging.info('running TriMap on %d points with dimension %d', n_points, dim)
  pca_solution = False
  if triplets is None:
    if verbose:
      logging.info('pre-processing')
    inputs, pca_solution = _preprocess_inputs(inputs, distance, apply_pca,
                                              verbose)
    key, use_key = random.split(key)
    # Give the sampler its own subkey. Passing `key` would reuse the key that
    # is split again below for the random initialization.
    triplets, weights = generate_triplets(
        use_key,
        inputs,
        n_inliers,
        n_outliers,
        n_random,
        weight_temp=weight_temp,
        distance=distance,
        verbose=verbose)
    if verbose:
      logging.info('sampled triplets')
  else:
    if weights is None:
      raise ValueError('weights must be provided together with triplets.')
    if verbose:
      logging.info('using pre-computed triplets')

  key, init_key = random.split(key)
  embedding = _initial_embedding(init_key, inputs, init_embedding, n_points,
                                 n_dims, pca_solution)

  n_triplets = float(triplets.shape[0])
  lr = lr * n_points / n_triplets
  if verbose:
    logging.info('running TriMap using DBD')
  vel = jnp.zeros_like(embedding, dtype=jnp.float32)
  gain = jnp.ones_like(embedding, dtype=jnp.float32)

  trimap_grad = jax.jit(jax.grad(trimap_loss))

  for itr in range(n_iters):
    # Look-ahead gradient at embedding + gamma * vel. Uses the same momentum
    # schedule as update_embedding_dbd.
    gamma = _FINAL_MOMENTUM if itr > _SWITCH_ITER else _INIT_MOMENTUM
    grad = trimap_grad(embedding + gamma * vel, triplets, weights)

    # update_embedding_dbd returns (embedding, gain, vel), in that order.
    embedding, gain, vel = update_embedding_dbd(embedding, grad, vel, gain, lr,
                                                itr)
    if verbose and (itr + 1) % _DISPLAY_ITER == 0:
      loss, n_violated = trimap_metrics(embedding, triplets, weights)
      logging.info(
          'Iteration: %4d / %4d, Loss: %3.3f, Violated triplets: %0.4f',
          itr + 1, n_iters, loss, n_violated / n_triplets * 100.0)
  if verbose:
    elapsed = str(datetime.timedelta(seconds=time.time() - start_time))
    logging.info('Elapsed time: %s', elapsed)
  return embedding
508

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.