# google-research
# 202 lines · 6.3 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains all functions to construct affinity graph for CNC and siamese nets."""
17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20
21import numpy as np22import tensorflow.compat.v1 as tf23from tensorflow.compat.v1.keras import backend as K24
25
def squared_distance(input_x, input_y=None, weight=None):
  """Calculates the pairwise distance between points in X and Y.

  Args:
    input_x: n x d matrix
    input_y: m x d matrix; when omitted, input_x is compared against itself
    weight: affinity n x m -- if provided, we normalize the distance

  Returns:
    n x m matrix of all pairwise squared Euclidean distances
  """
  if input_y is None:
    input_y = input_x
  # Axes to reduce after broadcasting: everything from axis 2 onward. This is
  # computed from the rank of input_x *before* the extra axis is inserted.
  sum_dimensions = list(range(2, K.ndim(input_x) + 1))
  # Insert an axis at position 1 so that the subtraction below broadcasts
  # every row of input_x against every row of input_y.
  input_x = K.expand_dims(input_x, axis=1)
  if weight is not None:
    # if weight provided, we normalize input_x and input_y by weight
    d_diag = K.expand_dims(K.sqrt(K.sum(weight, axis=1)), axis=1)
    # Rebind rather than use `/=` so that a caller-supplied array can never be
    # modified in place.
    input_x = input_x / d_diag
    input_y = input_y / d_diag
  squared_difference = K.square(input_x - input_y)
  distance = K.sum(squared_difference, axis=sum_dimensions)
  return distance
50
def knn_affinity(input_x,
                 n_nbrs,
                 scale=None,
                 scale_nbr=None,
                 local_scale=None,
                 verbose=False):
  """Calculates Gaussian affinity matrix.

  Calculates the symmetrized Gaussian affinity matrix with k1 nonzero
  affinities for each point, scaled by
  1) a provided scale,
  2) the median distance of the k2-th neighbor of each point in X, or
  3) a covariance matrix S where S_ii is the distance of the k2-th
  neighbor of each point i, and S_ij = 0 for all i != j
  Here, k1 = n_nbrs, k2 = scale_nbr

  Args:
    input_x: input dataset of size n
    n_nbrs: k1; a float is truncated to int, and a non-int32 tf.Variable is
      cast to int32
    scale: provided scale; when given, scale_nbr and local_scale are ignored
    scale_nbr: k2, used if scale not provided; when None, the last of the
      n_nbrs neighbors is used
    local_scale: if True, then we use the aforementioned option 3), else we
      use option 2)
    verbose: extra printouts (only used in the local_scale branch)

  Returns:
    n x n affinity matrix
  """
  if isinstance(n_nbrs, float):
    n_nbrs = int(n_nbrs)
  elif isinstance(n_nbrs,
                  tf.Variable) and n_nbrs.dtype.as_numpy_dtype != np.int32:
    n_nbrs = tf.cast(n_nbrs, np.int32)
  # get squared distance
  dist_x = squared_distance(input_x)
  # calculate the top k closest neighbors (top_k picks the largest values, so
  # the distances are negated to select the smallest ones)
  nn = tf.nn.top_k(-dist_x, n_nbrs, sorted=True)

  # vals holds the *negated* squared distances to the selected neighbors
  vals = nn[0]
  # apply scale
  if scale is None:
    # if scale not provided, use local scale
    if scale_nbr is None:
      # scale_nbr - 1 == -1 below, i.e. the last selected neighbor
      scale_nbr = 0
    else:
      assert scale_nbr > 0 and scale_nbr <= n_nbrs
    if local_scale:
      # per-point scale: distance to each point's scale_nbr-th neighbor,
      # repeated n_nbrs times so it lines up with the flattened vals
      scale = -nn[0][:, scale_nbr - 1]
      scale = tf.reshape(scale, [-1, 1])
      scale = tf.tile(scale, [1, n_nbrs])
      scale = tf.reshape(scale, [-1, 1])
      vals = tf.reshape(vals, [-1, 1])
      if verbose:
        vals = tf.Print(vals, [tf.shape(vals), tf.shape(scale)],
                        'vals, scale shape')
      vals = vals / (2 * scale)
      vals = tf.reshape(vals, [-1, n_nbrs])
    else:

      def get_median(scales, m):
        """Returns the m-th largest entry of scales and the top-m entries."""
        with tf.device('/cpu:0'):
          scales = tf.nn.top_k(scales, m)[0]
        scale = scales[m - 1]
        return scale, scales

      # global scale: with m = n // 2, the m-th largest neighbor distance is
      # used as the median over all points
      scales = -vals[:, scale_nbr - 1]
      const = tf.shape(input_x)[0] // 2
      scale, scales = get_median(scales, const)
      vals = vals / (2 * scale)
  else:
    # otherwise, use provided value for global scale
    vals = vals / (2 * scale**2)

  # get the affinity
  aff_vals = tf.exp(vals)
  # flatten this into a single vector of values to shove in a sparse matrix
  aff_vals = tf.reshape(aff_vals, [-1])
  # get the matrix of indices corresponding to each rank
  # with 1 in the first column and k in the kth column
  nn_ind = nn[1]
  # get the j index for the sparse matrix
  j_index = tf.reshape(nn_ind, [-1, 1])
  # the i index is just sequential to the j matrix
  i_index = tf.range(tf.shape(nn_ind)[0])
  i_index = tf.reshape(i_index, [-1, 1])
  i_index = tf.tile(i_index, [1, tf.shape(nn_ind)[1]])
  i_index = tf.reshape(i_index, [-1, 1])
  # concatenate the indices to build the sparse matrix
  indices = tf.concat((i_index, j_index), axis=1)
  # assemble the sparse weight matrix
  weight_mat = tf.SparseTensor(
      indices=tf.cast(indices, dtype='int64'),
      values=aff_vals,
      dense_shape=tf.cast(tf.shape(dist_x), dtype='int64'))
  # fix the ordering of the indices
  weight_mat = tf.sparse_reorder(weight_mat)
  # convert to dense tensor
  weight_mat = tf.sparse_tensor_to_dense(weight_mat)
  # symmetrize
  weight_mat = (weight_mat + tf.transpose(weight_mat)) / 2.0

  return weight_mat
154
def full_affinity(input_x, scale):
  """Calculates the symmetrized full Gaussian affinity matrix.

  The affinity is exp(-dist / (2 * scale^2)), where dist is the pairwise
  squared distance returned by squared_distance(input_x).

  Args:
    input_x: input dataset of size n x d
    scale: provided scale

  Returns:
    n x n affinity matrix
  """
  # Wrap the scale in a backend variable before squaring it.
  sigma = K.variable(scale)
  dist_x = squared_distance(input_x)
  # Trailing axis added so the squared scale broadcasts against dist_x.
  sigma_squared = K.expand_dims(K.pow(sigma, 2), -1)
  weight_mat = K.exp(-dist_x / (2 * sigma_squared))
  return weight_mat
171
def get_contrastive_loss(m_neg=1, m_pos=.2):
  """Contrastive loss from Hadsell-et-al.'06.

  http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf.

  Args:
    m_neg: margin for the y_true == 0 term; predictions below it are
      penalized.
    m_pos: margin for the y_true == 1 term; predictions above it are
      penalized.

  Returns:
    Contrastive loss function taking (y_true, y_pred).
  """

  def contrastive_loss(y_true, y_pred):
    """Mean of the two squared hinge terms, selected by y_true."""
    return K.mean(y_true * K.square(K.maximum(y_pred - m_pos, 0)) +
                  (1 - y_true) * K.square(K.maximum(m_neg - y_pred, 0)))

  return contrastive_loss
191
def euclidean_distance(vects):
  """Computes the euclidean distances between vects[0] and vects[1].

  Args:
    vects: pair (x, y) of tensors to compare.

  Returns:
    The distance reduced over axis 1, with that axis kept as size 1.
  """
  x, y = vects
  # Floor the squared sum at epsilon before taking the square root, so the
  # root is never taken at exactly zero.
  return K.sqrt(
      K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))
198
def eucl_dist_output_shape(shapes):
  """Provides the output shape of the above computation.

  Args:
    shapes: pair of input shapes; only the first one is used.

  Returns:
    Tuple (s_1[0], 1), where s_1 is the first input shape.
  """
  s_1, _ = shapes
  return (s_1[0], 1)