google-research

Форк
0
272 строки · 9.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contains functions used for creating pairs from labeled and unlabeled data (currently used only for the siamese network)."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21
import collections
22
import random
23
import numpy as np
24
from sklearn import metrics
25
from sklearn.neighbors import NearestNeighbors
26

27

28
def get_choices(arr, num_choices, valid_range, not_arr=None, replace=False):
  """Select n=num_choices choices from arr, with the following constraints.

  Every choice c satisfies valid_range[0] < c < valid_range[1] and
  c not in not_arr. Candidates are drawn with the `random` module.

  Args:
    arr: either a sequence of candidate values, or a tuple (start, stop), in
      which case the pool of choices is the integers in [start, stop). The
      tuple form always draws with replacement, regardless of `replace`.
    num_choices: number of choices to return.
    valid_range: (low, high) pair of exclusive bounds; an integer `high` is
      shorthand for (0, high).
    not_arr: a single value or a collection of values that must not be
      chosen; None means no exclusions.
    replace: if True, draw choices with replacement (sequence form only).

  Returns:
    A list of `num_choices` ints.

  Raises:
    ValueError: if fewer than `num_choices` candidates lie inside
      valid_range, or if the exclusions in `not_arr` leave too few
      selectable candidates (the sampling loop could otherwise never end).
  """
  if not_arr is None:
    not_arr = []
  elif isinstance(not_arr, (int, np.integer)):
    # A bare integer is a single excluded value; set() cannot iterate it.
    not_arr = [not_arr]
  not_arr_set = set(not_arr)
  if isinstance(valid_range, (int, np.integer)):
    valid_range = [0, valid_range]
  low, high = valid_range[0], valid_range[1]

  if isinstance(arr, tuple):
    arr0, n_arr = arr[0], arr[1]
    n_valid = _count_ints_strictly_between(arr0, n_arr, low, high)
    if n_valid < num_choices:
      raise ValueError('Not enough elements in arr are inside of valid_range!')
    # Distinct excluded values that would otherwise be selectable.
    n_excluded = len({
        v for v in not_arr_set
        if arr0 <= v < n_arr and low < v < high and float(v).is_integer()
    })
    if num_choices > 0 and n_valid - n_excluded < 1:
      raise ValueError(
          'Every element of arr inside valid_range is excluded by not_arr!')
    # Choices are the integers themselves and are drawn with replacement, so
    # no backing array is needed for swapping.
    get_arr = lambda x: x
    replace = True
  else:
    values = np.array(arr)  # a copy: entries are swapped in place below
    in_range = np.logical_and(values > low, values < high)
    if np.sum(in_range) < num_choices:
      raise ValueError('Not enough elements in arr are inside of valid_range!')
    n_selectable = sum(
        1 for v, ok in zip(values.tolist(), in_range.tolist())
        if ok and v not in not_arr_set)
    # Without replacement each selectable entry can be used once; with
    # replacement a single selectable entry suffices to terminate.
    needed = min(num_choices, 1) if replace else num_choices
    if n_selectable < needed:
      raise ValueError('Not enough elements in arr are inside of valid_range '
                       'and outside of not_arr!')
    n_arr = len(values)
    arr0 = 0
    arr = values
    get_arr = lambda x: arr[x]

  def get_choice():
    # Rejection-sample a position in [arr0, n_arr) whose value is not excluded.
    arr_idx = random.randint(arr0, n_arr - 1)
    while get_arr(arr_idx) in not_arr_set:
      arr_idx = random.randint(arr0, n_arr - 1)
    return arr_idx

  choices = []
  for _ in range(num_choices):
    arr_idx = get_choice()
    while get_arr(arr_idx) <= low or get_arr(arr_idx) >= high:
      arr_idx = get_choice()
    choices.append(int(get_arr(arr_idx)))
    if not replace:
      # Move the used entry just past the live prefix [0, n_arr) so it
      # cannot be drawn again.
      arr[arr_idx], arr[n_arr - 1] = arr[n_arr - 1], arr[arr_idx]
      n_arr -= 1
  return choices


def _count_ints_strictly_between(start, stop, low, high):
  """Counts the integers x with start <= x < stop and low < x < high."""
  # Clip the (possibly infinite) open bounds into [start - 1, stop] so that
  # floor/ceil below operate on finite numbers.
  low = min(max(low, start - 1), stop)
  high = max(min(high, stop), start - 1)
  first = int(np.floor(low)) + 1
  last = int(np.ceil(high)) - 1
  return max(0, last - first + 1)
85

86

87
def create_pairs_from_labeled_data(x, digit_indices, use_classes=None):
  """Builds alternating positive/negative pairs from labeled data.

  For each requested class c and each position i, emits the positive pair
  (x[members_c[i]], x[members_c[i + 1]]) followed by the negative pair
  (x[members_c[i]], x[members_o[i]]), where o is a class other than c chosen
  uniformly at random via `random.randrange`. The number of positions per
  class is one less than the size of the smallest class in digit_indices.

  Args:
    x: labeled data; must expose `.shape`.
    digit_indices: jagged 2-D structure whose row i holds the indices in x of
      every example labeled with class i.
    use_classes: optional list of the classes to anchor pairs on; None means
      every class. Negative partners may come from any class.

  Returns:
    pairs: array of shape (num_pairs, 2) + x.shape[1:].
    labels: array of 1 (positive) / 0 (negative), alternating, one per pair.
  """
  num_classes = len(digit_indices)
  classes = list(range(num_classes)) if use_classes is None else use_classes
  # Bounded by the smallest class so that index i + 1 (and i in any other
  # class) is always valid.
  per_class = min(len(digit_indices[c]) for c in range(num_classes)) - 1

  pairs, labels = [], []
  for anchor_class in classes:
    members = digit_indices[anchor_class]
    for i in range(per_class):
      anchor = members[i]
      pairs.append([x[anchor], x[members[i + 1]]])
      offset = random.randrange(1, num_classes)
      other_class = (anchor_class + offset) % num_classes
      pairs.append([x[anchor], x[digit_indices[other_class][i]]])
      labels.extend((1, 0))

  pair_shape = (len(pairs), 2) + x.shape[1:]
  return np.array(pairs).reshape(pair_shape), np.array(labels)
123

124

125
def create_pairs_from_unlabeled_data(x1,
                                     x2=None,
                                     y=None,
                                     p=None,
                                     k=5,
                                     tot_pairs=None,
                                     pre_shuffled=False,
                                     verbose=None):
  """Generates positive and negative pairs for the siamese network from unlabeled data.

  Draws from the k nearest neighbors (where k is the provided parameter) of
  each point to form pairs. The number of neighbors to draw per point is
  determined by tot_pairs, if provided, or k if not provided.

  For every point i, `pairs_per_pt` positive pairs are formed with points in
  i's nearest-neighbor list, followed by the same number of negative pairs
  with points drawn from outside that list. Points for which either draw
  fails are skipped.

  Args:
    x1: input data array; rows are samples.
    x2: optional parallel data array with the same shape as x1. Pairs drawn
      from x2 use exactly the same indices as the pairs drawn from x1.
    y: optional true labels, used purely for checking how good the pairs are.
    p: optional permutation vector. When the arrays are shuffled and a
      precomputed knn matrix (computed on unshuffled data) is used, p tracks
      the permutation. Only the first len(p) samples are used; unless
      pre_shuffled is True, x1, x2 and y are all reindexed by p[:n].
    k: the number of neighbors to use (the 'k' in knn). Clamped to
      len(samples) - 1 for the neighbor search.
    tot_pairs: approximate total number of pairs to produce; if given, the
      per-point count is int(tot_pairs / (2 * n)), capped at k and at least 1.
    pre_shuffled: if True, the arrays are assumed to already be permuted by p
      and are not reindexed.
    verbose: if truthy, print progress messages and, when y is given, a
      confusion matrix of knn-inferred versus true pair labels.

  Returns:
    A list [pairs for x1, (pairs for x2 if x2 is provided), labels inferred by
    knn (1 = positive, 0 = negative), (labels_true, an (N, 1) int array that
    is 1 when both points of a pair share a y label, if y is provided)].

  Raises:
    ValueError: if x2 is given and its shape differs from x1's.
  """
  if x2 is not None and x1.shape != x2.shape:
    raise ValueError('x1 and x2 must be the same shape!')

  n = len(p) if p is not None else len(x1)

  # Each point contributes pairs_per_pt positives and as many negatives.
  if tot_pairs is not None:
    pairs_per_pt = max(1, min(k, int(tot_pairs / (n * 2))))
  else:
    pairs_per_pt = max(1, k)

  if p is not None and not pre_shuffled:
    # Permute every per-sample array so that row i refers to the same sample
    # in x1, x2 and y.
    x1 = x1[p[:n]]
    if x2 is not None:
      x2 = x2[p[:n]]
    if y is not None:
      y = y[p[:n]]

  if verbose:
    print('computing k={} nearest neighbors...'.format(k))
  if len(x1.shape) == 2:
    x1_flat = x1[:n]
  else:
    # NearestNeighbors needs a 2-D (samples, features) matrix.
    x1_flat = x1.reshape(x1.shape[0], -1)[:n]
  idx = _knn_indices_without_self(x1_flat, k)
  k_max = min(idx.shape[1], k + 1)

  if verbose:
    print('creating pairs...')
    print('ks', n, k_max, k, pairs_per_pt)

  pairs = []
  pairs2 = []
  labels = []
  labels_true = []
  # pair generation loop (positives then negatives for each point)
  consecutive_fails = 0
  for i in range(n):
    # get_choices sometimes fails with precomputed results. If this happens
    # too often, relax the constraint on k.
    if consecutive_fails > 5:
      k_max = min(idx.shape[1], int(k_max * 2))
      consecutive_fails = 0
    if verbose and i % 10000 == 0:
      print('Iter: {}/{}'.format(i, n))
    neighbors = idx[i, :k_max]
    # pick points from the neighbors of i for positive pairs
    try:
      pos_choices = get_choices(
          neighbors, pairs_per_pt, valid_range=[-1, np.inf], replace=False)
      consecutive_fails = 0
    except ValueError:
      consecutive_fails += 1
      continue
    assert i not in pos_choices
    # pick points *not* in the neighbors of i for negative pairs
    try:
      neg_choices = get_choices((0, n),
                                pairs_per_pt,
                                valid_range=[-1, np.inf],
                                not_arr=neighbors,
                                replace=False)
      consecutive_fails = 0
    except ValueError:
      consecutive_fails += 1
      continue

    partners = pos_choices + neg_choices
    labels += [1] * len(pos_choices) + [0] * len(neg_choices)
    pairs += [[x1[i], x1[c]] for c in partners]
    if x2 is not None:
      pairs2 += [[x2[i], x2[c]] for c in partners]
    if y is not None:
      labels_true += [[y[i] == y[c]] for c in partners]

  # package return parameters for output
  ret = [np.array(pairs).reshape((len(pairs), 2) + x1.shape[1:])]
  if x2 is not None:
    ret.append(np.array(pairs2).reshape((len(pairs2), 2) + x2.shape[1:]))
  ret.append(np.array(labels))
  if y is not None:
    labels_true = np.array(labels_true).astype(int).reshape(-1, 1)
    if verbose and labels_true.shape[0] > 0:
      # With true labels available, show how well the kNN-based labels
      # approximate them.
      print('confusion matrix for pairs and approximated labels:')
      print(metrics.confusion_matrix(labels_true, labels) /
            labels_true.shape[0])
      print(metrics.confusion_matrix(labels_true, labels))
    ret.append(labels_true)

  return ret


def _knn_indices_without_self(x_flat, k):
  """Returns an int array of each row's nearest-neighbor indices, minus itself.

  Each row holds min(k, len(x_flat) - 1) indices in the order returned by
  `NearestNeighbors.kneighbors`. The query point is removed from its own
  list; when it does not appear (e.g. exact duplicates were returned first),
  the last returned neighbor is dropped instead so all rows share one length.
  """
  n_neighbors = min(k, x_flat.shape[0] - 1) + 1
  nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(x_flat)
  _, idx = nbrs.kneighbors(x_flat)
  width = idx.shape[1] - 1
  without_self = np.empty((idx.shape[0], width), dtype=int)
  for i, row in enumerate(idx):
    without_self[i] = row[row != i][:width]
  return without_self
273

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.