# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains various utility functions used in the models."""

from __future__ import division
from __future__ import print_function

import math

from munkres import Munkres
import numpy as np
import sklearn.metrics
from sklearn.neighbors import NearestNeighbors
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras import backend as K
from tensorflow.compat.v1.keras.callbacks import Callback

def make_batches(size, batch_size):
  """Generates a list of tuples for batching data.

  Generates a list of (start_idx, end_idx) tuples for batching data
  of the given size with the given batch_size.

  Args:
    size:       size of the data to create batches for
    batch_size: batch size

  Returns:
    list of tuples of indices for the data
  """
  num_batches = (size + batch_size - 1) // batch_size  # round up
  return [(i * batch_size, min(size, (i + 1) * batch_size))
          for i in range(num_batches)]
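
# Example (illustrative, not part of the original module): for 10 samples and
# a batch size of 4, the final batch is truncated to the remaining indices:
#   make_batches(10, 4)  # -> [(0, 4), (4, 8), (8, 10)]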


def train_gen(pairs_train, dist_train, batch_size):
  """Generator used for training the siamese net with keras.

  Args:
    pairs_train:    training pairs
    dist_train:     training labels
    batch_size:     batch size

  Yields:
    tuples of ([x1, x2], y) batches, where x1 and x2 are the two sides of
    each pair and y is the corresponding label
  """
  batches = make_batches(len(pairs_train), batch_size)
  while True:
    random_idx = np.random.permutation(len(pairs_train))
    for batch_start, batch_end in batches:
      p_ = random_idx[batch_start:batch_end]
      x1, x2 = pairs_train[p_, 0], pairs_train[p_, 1]
      y = dist_train[p_]
      yield ([x1, x2], y)
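
# Example usage (sketch; `siamese_model`, `pairs_train`, and `dist_train` are
# assumed to be provided by the surrounding training script):
#   gen = train_gen(pairs_train, dist_train, batch_size=128)
#   siamese_model.fit_generator(
#       gen, steps_per_epoch=len(pairs_train) // 128, epochs=10)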


def make_layer_list(arch, network_type=None, reg=None, dropout=0):
  """Generates the list of layers.

  Generates the list of layers specified by arch, to be stacked
  by stack_layers.

  Args:
    arch:           list of dicts, where each dict contains the arguments to
      the corresponding layer function in stack_layers
    network_type:   siamese or cnc net. used only to name layers
    reg:            L2 regularization (if any)
    dropout:        dropout rate (if any)

  Returns:
    appropriately formatted list of layer dicts for stack_layers
  """
  layers = []
  for i, a in enumerate(arch):
    layer = {'l2_reg': reg}
    layer.update(a)
    if network_type:
      layer['name'] = '{}_{}'.format(network_type, i)
    layers.append(layer)
    if a['type'] != 'Flatten' and dropout != 0:
      dropout_layer = {
          'type': 'Dropout',
          'rate': dropout,
      }
      if network_type:
        dropout_layer['name'] = '{}_dropout_{}'.format(network_type, i)
      layers.append(dropout_layer)
  return layers
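
# Example (illustrative; the 'relu'/'size' fields are assumptions about what
# stack_layers expects, while the 'type'/'Dropout' handling is taken from the
# code above):
#   make_layer_list([{'type': 'relu', 'size': 512}],
#                   network_type='siamese', reg=1e-4, dropout=0.1)
#   # -> [{'l2_reg': 0.0001, 'type': 'relu', 'size': 512, 'name': 'siamese_0'},
#   #     {'type': 'Dropout', 'rate': 0.1, 'name': 'siamese_dropout_0'}]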
102

103

104
class LearningHandler(Callback):
105
  """Class for managing the learning rate scheduling and early stopping criteria.
106

107
  Learning rate scheduling is implemented by multiplying the learning rate
108
  by 'drop' everytime the validation loss does not see any improvement
109
  for 'patience' training steps
110
  """
111

112
  def __init__(self,
113
               lr,
114
               drop,
115
               lr_tensor,
116
               patience,
117
               tau_tensor=None,
118
               tau=1,
119
               min_tem=1,
120
               gumble=False):
121
    """initializer.
122

123
    Args:
124
      lr: initial learning rate
125
      drop: factor by which learning rate is reduced
126
      lr_tensor: tensorflow (or keras) tensor for the learning rate
127
      patience: patience of the learning rate scheduler
128
      tau_tensor: tensor to kepp the changed temperature
129
      tau: temperature
130
      min_tem: minimum temperature
131
      gumble: True if gumble is used
132
    """
133
    super(LearningHandler, self).__init__()
134
    self.lr = lr
135
    self.drop = drop
136
    self.lr_tensor = lr_tensor
137
    self.patience = patience
138
    self.tau = tau
139
    self.tau_tensor = tau_tensor
140
    self.min_tem = min_tem
141
    self.gumble = gumble
142

143
  def on_train_begin(self, logs=None):
144
    """Initialize the parameters at the start of training."""
145
    self.assign_op = tf.no_op()
146
    self.scheduler_stage = 0
147
    self.best_loss = np.inf
148
    self.wait = 0
149

150
  def on_epoch_end(self, epoch, logs=None):
151
    """For managing learning rate, early stopping, and temperature."""
152
    stop_training = False
153
    min_tem = self.min_tem
154
    anneal_rate = 0.00003
155
    if self.gumble and epoch % 20 == 0:
156
      self.tau = np.maximum(self.tau * np.exp(-anneal_rate * epoch), min_tem)
157
      K.set_value(self.tau_tensor, self.tau)
158
    # check if we need to stop or increase scheduler stage
159
    if isinstance(logs, dict):
160
      loss = logs['loss']
161
    else:
162
      loss = logs
163
    if loss <= self.best_loss:
164
      self.best_loss = loss
165
      self.wait = 0
166
    else:
167
      self.wait += 1
168
      if self.wait > self.patience:
169
        self.scheduler_stage += 1
170
        self.wait = 0
171
    if math.isnan(loss):
172
      stop_training = True
173
    # calculate and set learning rate
174
    lr = self.lr * np.power(self.drop, self.scheduler_stage)
175
    K.set_value(self.lr_tensor, lr)
176

177
    # built in stopping if lr is way too small
178
    if lr <= 1e-9:
179
      stop_training = True
180

181
    # for keras
182
    if hasattr(self, 'model') and self.model is not None:
183
      self.model.stop_training = stop_training
184

185
    return stop_training
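
# Example usage (sketch; `lr_tensor`, `x_train`, `y_train`, and `model` are
# assumed to come from the surrounding training code):
#   handler = LearningHandler(lr=1e-3, drop=0.1, lr_tensor=lr_tensor,
#                             patience=10)
#   model.fit(x_train, y_train, epochs=100, callbacks=[handler])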


def sample_gumbel(shape, eps=1e-20):
  """Samples from Gumbel(0, 1)."""
  samples = tf.random_uniform(shape, minval=0, maxval=1)
  return -tf.log(-tf.log(samples + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
  """Draws a sample from the Gumbel-Softmax distribution."""
  y = logits + sample_gumbel(tf.shape(logits))
  return tf.nn.softmax(y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
  """Samples from the Gumbel-Softmax distribution and optionally discretizes.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. the soft sample y

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it
    will be a probability distribution that sums to 1 across classes.
  """
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
    y = tf.stop_gradient(y_hard - y) + y
  return y
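
# Example (illustrative): draw a hard, straight-through sample from 10-way
# logits at temperature 0.5; the result is one-hot in the forward pass, but
# gradients flow through the soft Gumbel-Softmax sample:
#   logits = tf.random_normal([32, 10])
#   sample = gumbel_softmax(logits, temperature=0.5, hard=True)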


def get_scale(x, batch_size, n_nbrs):
  """Calculates the scale.

  The scale is based on the median distance of the kth
  neighbors of each point of x*, an m-sized sample of x, where
  k = n_nbrs and m = batch_size.

  Args:
    x:          data for which to compute the scale.
    batch_size: m in the aforementioned calculation.
    n_nbrs:     k in the aforementioned calculation.

  Returns:
    the scale, which is the variance term of the gaussian affinity matrix used
    by ncutnet
  """
  n = len(x)

  # sample a random batch of size batch_size
  sample = x[np.random.randint(n, size=batch_size), :]
  # flatten it
  sample = sample.reshape((batch_size, np.prod(sample.shape[1:])))

  # compute distances of the nearest neighbors
  nbrs = NearestNeighbors(n_neighbors=n_nbrs).fit(sample)
  distances, _ = nbrs.kneighbors(sample)

  # return the median distance to the (n_nbrs - 1)th neighbor
  return np.median(distances[:, n_nbrs - 1])
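
# Example (illustrative; `x_train` is an assumed numpy array of samples):
#   sigma = get_scale(x_train, batch_size=1024, n_nbrs=2)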


def calculate_cost_matrix(cluster, n_clusters):
  """Builds the Munkres cost matrix from a confusion matrix.

  cost_matrix[j, i] is the number of examples assigned to cluster j whose
  true label is not i, i.e. the cost of mapping cluster j to label i.

  Args:
    cluster:    confusion matrix of true labels vs. cluster assignments
    n_clusters: number of clusters in the dataset

  Returns:
    the n_clusters x n_clusters cost matrix
  """
  cost_matrix = np.zeros((n_clusters, n_clusters))

  # cost_matrix[i,j] will be the cost of assigning cluster i to label j
  for j in range(n_clusters):
    s = np.sum(cluster[:, j])  # number of examples assigned to cluster j
    for i in range(n_clusters):
      t = cluster[i, j]
      cost_matrix[j, i] = s - t
  return cost_matrix


def get_cluster_labels_from_indices(indices):
  """Extracts the cluster-to-label mapping from the Munkres index pairs."""
  n_clusters = len(indices)
  cluster_labels = np.zeros(n_clusters)
  for i in range(n_clusters):
    cluster_labels[i] = indices[i][1]
  return cluster_labels


def get_accuracy(cluster_assignments, y_true, n_clusters):
  """Computes accuracy.

  Computes the accuracy based on the cluster assignments
  and true labels, using the Munkres algorithm.

  Args:
    cluster_assignments:    array of labels, output by kmeans
    y_true:                 true labels
    n_clusters:             number of clusters in the dataset

  Returns:
    a tuple containing the accuracy and confusion matrix, in that order
  """
  y_pred, confusion_matrix = get_y_preds(cluster_assignments, y_true,
                                         n_clusters)
  # calculate the accuracy
  return np.mean(y_pred == y_true), confusion_matrix


def print_accuracy(cluster_assignments,
                   y_true,
                   n_clusters,
                   extra_identifier=''):
  """Prints the confusion matrix and the clustering accuracy."""
  # get the accuracy and confusion matrix
  accuracy, confusion_matrix = get_accuracy(cluster_assignments, y_true,
                                            n_clusters)
  # print the confusion matrix
  print('confusion matrix{}: '.format(extra_identifier))
  print(confusion_matrix)
  print(('Cnc_net{} accuracy: '.format(extra_identifier) +
         str(np.round(accuracy, 3))))
  return str(np.round(accuracy, 3))


def get_y_preds(cluster_assignments, y_true, n_clusters):
  """Computes the predicted labels.

  Label assignments now correspond to the actual labels in
  y_true (as estimated by Munkres).

  Args:
    cluster_assignments:    array of labels, output by kmeans
    y_true:                 true labels
    n_clusters:             number of clusters in the dataset

  Returns:
    a tuple containing the predicted labels and confusion matrix, in that order
  """
  confusion_matrix = sklearn.metrics.confusion_matrix(
      y_true, cluster_assignments, labels=None)
  # compute accuracy based on optimal 1:1 assignment of clusters to labels
  cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters)
  indices = Munkres().compute(cost_matrix)
  true_cluster_labels = get_cluster_labels_from_indices(indices)
  y_pred = true_cluster_labels[cluster_assignments]
  return y_pred, confusion_matrix
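
# Example (illustrative; `assignments` and `y_true` are assumed to be integer
# arrays of the same length, e.g. k-means cluster indices and ground-truth
# labels):
#   y_pred, cm = get_y_preds(assignments, y_true, n_clusters=10)
#   acc, _ = get_accuracy(assignments, y_true, n_clusters=10)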