google-research
272 строки · 9.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Contains functions used for creating pairs from labeled and unlabeled data (currently used only for the siamese network)."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21import collections22import random23import numpy as np24from sklearn import metrics25from sklearn.neighbors import NearestNeighbors26
27
def get_choices(arr, num_choices, valid_range, not_arr=None, replace=False):
    """Draw `num_choices` values from a pool under range/exclusion constraints.

    Candidates are drawn uniformly at random by rejection sampling: a draw is
    retried while it is excluded by `not_arr` or falls outside `valid_range`.

    Args:
      arr: the pool of candidates. Either an array-like of values, or a tuple
        whose first two entries `(low, high)` stand for the integers
        `low, ..., high - 1`. In the tuple form draws are always made with
        replacement, whatever `replace` says.
      num_choices: number of values to draw.
      valid_range: `[low, high]`; a drawn value must satisfy
        `low < value < high` (both bounds exclusive). A bare int `h` is
        shorthand for `[0, h]`.
      not_arr: values that must never be drawn. An iterable of values, a
        single integer, or None for no exclusions.
      replace: if True, draw with replacement. Only honoured when `arr` is
        array-like.

    Returns:
      A list of `num_choices` Python ints.

    Raises:
      ValueError: if fewer than `num_choices` candidates lie inside
        `valid_range`. Note that this check does not take `not_arr` into
        account, so a pool whose in-range candidates are all excluded makes
        the rejection loop spin forever.
    """
    if not_arr is None:
        not_arr = []
    elif isinstance(not_arr, (int, np.integer)):
        # A lone integer is one excluded value. It has to be wrapped before
        # set() is called, since neither set(int) nor list(int) is valid.
        not_arr = [not_arr]
    excluded = set(not_arr)

    if isinstance(valid_range, int):
        valid_range = [0, valid_range]
    low, high = valid_range[0], valid_range[1]

    if isinstance(arr, tuple):
        # Implicit integer pool: positions and values coincide, so no backing
        # array is needed and removal-by-swap is impossible; hence draws are
        # forced to be with replacement.
        first, stop = arr[0], arr[1]
        if min(stop, high) - max(first, low) < num_choices:
            raise ValueError(
                'Not enough elements in arr are outside of valid_range!')
        pool = None
        replace = True

        def value_at(pos):
            return pos
    else:
        # Work on a private copy: drawing without replacement reorders it.
        pool = np.array(arr, copy=True)
        in_range = np.logical_and(pool > low, pool < high)
        if np.sum(in_range) < num_choices:
            raise ValueError(
                'Not enough elements in arr are outside of valid_range!')
        first, stop = 0, len(arr)

        def value_at(pos):
            return pool[pos]

    choices = []
    for _ in range(num_choices):
        while True:
            pos = random.randint(first, stop - 1)
            value = value_at(pos)
            if value in excluded:
                continue
            if low < value < high:
                break
        choices.append(int(value))
        if not replace:
            # Retire the chosen slot by swapping it to the end of the live
            # region and shrinking that region by one.
            pool[pos], pool[stop - 1] = pool[stop - 1], pool[pos]
            stop -= 1
    return choices
86
def create_pairs_from_labeled_data(x, digit_indices, use_classes=None):
    """Builds alternating positive/negative pairs from labeled examples.

    For every selected class and every position `i`, one positive pair joins
    the class's i-th and (i+1)-th examples, and one negative pair joins the
    class's i-th example with the i-th example of a randomly chosen different
    class.

    Args:
      x: labeled data, indexable by the entries of `digit_indices`.
      digit_indices: jagged 2-level array; row `c` lists the indices in `x`
        of every example labeled with class `c`.
      use_classes: optional list of classes to draw anchor examples from;
        None means every class.

    Returns:
      pairs: array of shape `(num_pairs, 2) + x.shape[1:]`, positives and
        negatives interleaved.
      labels: array of 1 (positive) / 0 (negative), aligned with `pairs`.
    """
    num_classes = len(digit_indices)
    classes = list(range(num_classes)) if use_classes is None else use_classes

    # The smallest class bounds how many consecutive (i, i + 1) pairs exist.
    steps = min([len(digit_indices[c]) for c in range(num_classes)]) - 1

    pair_list = []
    label_list = []
    for cls in classes:
        members = digit_indices[cls]
        for pos in range(steps):
            # Positive pair: two neighbouring examples of the same class.
            pair_list.append([x[members[pos]], x[members[pos + 1]]])
            # Negative pair: a non-zero cyclic shift always lands on another
            # class.
            shift = random.randrange(1, num_classes)
            other_cls = (cls + shift) % num_classes
            pair_list.append([x[members[pos]], x[digit_indices[other_cls][pos]]])
            label_list.extend([1, 0])

    pairs = np.array(pair_list).reshape((len(pair_list), 2) + x.shape[1:])
    labels = np.array(label_list)
    return pairs, labels
124
def create_pairs_from_unlabeled_data(x1,
                                     x2=None,
                                     y=None,
                                     p=None,
                                     k=5,
                                     tot_pairs=None,
                                     pre_shuffled=False,
                                     verbose=None):
    """Generates positive and negative pairs for the siamese network from unlabeled data.

    Positive pairs join each point with some of its k nearest neighbors;
    negative pairs join it with points drawn from outside that neighbor list.
    The number of pairs drawn per point is derived from `tot_pairs` if given,
    otherwise it is `k`.

    Args:
      x1: input data array.
      x2: parallel data array (pairs will exactly shadow the indices of x1,
        but be drawn from x2). Must have the same shape as x1.
      y: true labels (if available), purely for checking how good the pairs
        are.
      p: permutation vector - in cases where the array is shuffled and we use
        a precomputed knn matrix (where knn is performed on unshuffled data),
        we keep track of the permutations with p, and apply the same
        permutation to the data.
      k: the number of neighbors to use (the 'k' in knn).
      tot_pairs: approximate total number of pairs to produce; each point
        contributes `max(1, min(k, tot_pairs // (2 * n)))` positive and as
        many negative pairs.
      pre_shuffled: if True, the data is assumed to be permuted by `p`
        already and is not permuted again.
      verbose: if truthy, print progress and (when `y` is given) a confusion
        matrix of kNN-inferred labels against the true labels.

    Returns:
      A list: pairs for x1, (pairs for x2 if x2 is provided), labels inferred
      by knn, (labels_true, the absolute truth, if y is provided).

    Raises:
      ValueError: if x2 is provided and its shape differs from x1's.
    """
    if x2 is not None and x1.shape != x2.shape:
        raise ValueError('x1 and x2 must be the same shape!')

    n = len(p) if p is not None else len(x1)

    if tot_pairs is not None:
        pairs_per_pt = max(1, min(k, int(tot_pairs / (n * 2))))
    else:
        pairs_per_pt = max(1, k)

    if p is not None and not pre_shuffled:
        x1 = x1[p[:n]]
        # y is optional; only permute it when it was actually supplied.
        if y is not None:
            y = y[p[:n]]
        # NOTE(review): x2 is left in its original order while x1 is permuted;
        # confirm callers pass an x2 that is already aligned with x1[p].

    pairs = []
    pairs2 = []
    labels = []
    true = []

    if verbose:
        print('computing k={} nearest neighbors...'.format(k))
    # kNN operates on 2-D input, so flatten any trailing feature dimensions.
    if len(x1.shape) > 2:
        x1_flat = x1.reshape(x1.shape[0], np.prod(x1.shape[1:]))[:n]
    else:
        x1_flat = x1[:n]
    # Ask for k + 1 neighbors because each point is normally returned as its
    # own nearest neighbor.
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(x1_flat)
    _, idx = nbrs.kneighbors(x1_flat)

    # for each row, remove the element itself from its list of neighbors
    # (we don't care that each point is its own closest neighbor)
    assert (idx >= 0).all()
    new_idx = np.empty((idx.shape[0], idx.shape[1] - 1))
    for i in range(idx.shape[0]):
        # If i is missing from its own row (e.g. duplicate points), the slice
        # simply drops the farthest neighbor instead.
        new_idx[i] = idx[i, idx[i] != i][:idx.shape[1] - 1]
    idx = new_idx.astype(int)
    k_max = min(idx.shape[1], k + 1)

    if verbose:
        print('creating pairs...')
        print('ks', n, k_max, k, pairs_per_pt)

    # pair generation loop (alternates between true and false pairs)
    consecutive_fails = 0
    for i in range(n):
        # get_choices sometimes fails with precomputed results. if this
        # happens too often, we relax the constraint on k
        if consecutive_fails > 5:
            k_max = min(idx.shape[1], int(k_max * 2))
            consecutive_fails = 0
        if verbose and i % 10000 == 0:
            print('Iter: {}/{}'.format(i, n))

        # pick points from neighbors of i for positive pairs
        try:
            choices = get_choices(
                idx[i, :k_max],
                pairs_per_pt,
                valid_range=[-1, np.inf],
                replace=False)
            consecutive_fails = 0
        except ValueError:
            consecutive_fails += 1
            continue
        assert i not in choices
        # form the pairs
        new_pos = [[x1[i], x1[c]] for c in choices]
        if x2 is not None:
            new_pos2 = [[x2[i], x2[c]] for c in choices]
        if y is not None:
            pos_labels = [[y[i] == y[c]] for c in choices]

        # pick points *not* in neighbors of i for negative pairs
        try:
            choices = get_choices((0, n),
                                  pairs_per_pt,
                                  valid_range=[-1, np.inf],
                                  not_arr=idx[i, :k_max],
                                  replace=False)
            consecutive_fails = 0
        except ValueError:
            consecutive_fails += 1
            continue
        # form negative pairs
        new_neg = [[x1[i], x1[c]] for c in choices]
        if x2 is not None:
            new_neg2 = [[x2[i], x2[c]] for c in choices]
        if y is not None:
            neg_labels = [[y[i] == y[c]] for c in choices]

        # add pairs to our list
        labels += [1] * len(new_pos) + [0] * len(new_neg)
        pairs += new_pos + new_neg
        if x2 is not None:
            pairs2 += new_pos2 + new_neg2
        if y is not None:
            true += pos_labels + neg_labels

    # package return parameters for output
    ret = [np.array(pairs).reshape((len(pairs), 2) + x1.shape[1:])]
    if x2 is not None:
        ret.append(np.array(pairs2).reshape((len(pairs2), 2) + x2.shape[1:]))
    ret.append(np.array(labels))
    if y is not None:
        true = np.array(true).astype(int).reshape(-1, 1)
        if verbose:
            # if true vectors are provided, we can take a peek to check
            # the validity of our kNN approximation
            print('confusion matrix for pairs and approximated labels:')
            print(metrics.confusion_matrix(true, labels) / true.shape[0])
            print(metrics.confusion_matrix(true, labels))
        ret.append(true)

    return ret