# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains various utility functions used in the models."""
from __future__ import division
from __future__ import print_function

import math

from munkres import Munkres
import numpy as np
import sklearn.metrics
from sklearn.neighbors import NearestNeighbors
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras import backend as K
from tensorflow.compat.v1.keras.callbacks import Callback


def make_batches(size, batch_size):
  """Generates a list of tuples for batching data.

  Generates a list of (start_idx, end_idx) tuples for batching data
  of the given size and batch_size.

  Args:
    size: size of the data to create batches for
    batch_size: batch size

  Returns:
    list of tuples of indices for data
  """
  num_batches = (size + batch_size - 1) // batch_size  # round up
  return [(i * batch_size, min(size, (i + 1) * batch_size))
          for i in range(num_batches)]
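

# For example, make_batches(10, 4) returns [(0, 4), (4, 8), (8, 10)]; the
# last batch is truncated at the data size rather than padded.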


def train_gen(pairs_train, dist_train, batch_size):
  """Generator used for training the siamese net with Keras.

  Args:
    pairs_train: training pairs
    dist_train: training labels
    batch_size: batch size

  Yields:
    ([x1, x2], y) batches, where x1 and x2 are the two sides of each pair
    and y is the corresponding label, drawn in a fresh random order on
    every pass over the data
  """
  batches = make_batches(len(pairs_train), batch_size)
  while True:
    random_idx = np.random.permutation(len(pairs_train))
    for batch_start, batch_end in batches:
      p_ = random_idx[batch_start:batch_end]
      x1, x2 = pairs_train[p_, 0], pairs_train[p_, 1]
      y = dist_train[p_]
      yield ([x1, x2], y)
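

# A minimal usage sketch (hypothetical names: assumes `siamese_net` is a
# compiled two-input Keras model, and pairs_train/dist_train are numpy
# arrays of shapes [n, 2, d] and [n]):
#
#   gen = train_gen(pairs_train, dist_train, batch_size=128)
#   steps = len(make_batches(len(pairs_train), 128))
#   siamese_net.fit_generator(gen, steps_per_epoch=steps, epochs=10)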


def make_layer_list(arch, network_type=None, reg=None, dropout=0):
  """Generates the list of layers.

  Generates the list of layers specified by arch, to be stacked
  by stack_layers. When dropout is nonzero, a Dropout layer is appended
  after every non-Flatten layer.

  Args:
    arch: list of dicts, where each dict contains the arguments to the
      corresponding layer function in stack_layers
    network_type: siamese or cnc net; used only to name layers
    reg: L2 regularization (if any)
    dropout: dropout rate (if any)

  Returns:
    appropriately formatted stack_layers dictionary
  """
  layers = []
  for i, a in enumerate(arch):
    layer = {'l2_reg': reg}
    layer.update(a)
    if network_type:
      layer['name'] = '{}_{}'.format(network_type, i)
    layers.append(layer)
    if a['type'] != 'Flatten' and dropout != 0:
      dropout_layer = {
          'type': 'Dropout',
          'rate': dropout,
      }
      if network_type:
        dropout_layer['name'] = '{}_dropout_{}'.format(network_type, i)
      layers.append(dropout_layer)
  return layers
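

# For example (hypothetical arch; the valid 'type' values are defined by
# stack_layers):
#
#   make_layer_list([{'type': 'relu', 'size': 512}], network_type='cnc',
#                   reg=1e-4, dropout=0.1)
#
# returns [{'l2_reg': 1e-4, 'type': 'relu', 'size': 512, 'name': 'cnc_0'},
#          {'type': 'Dropout', 'rate': 0.1, 'name': 'cnc_dropout_0'}].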


class LearningHandler(Callback):
  """Manages the learning rate schedule and early stopping criteria.

  Learning rate scheduling is implemented by multiplying the learning rate
  by 'drop' every time the validation loss sees no improvement for
  'patience' training steps.
  """

  def __init__(self,
               lr,
               drop,
               lr_tensor,
               patience,
               tau_tensor=None,
               tau=1,
               min_tem=1,
               gumble=False):
    """Initializer.

    Args:
      lr: initial learning rate
      drop: factor by which the learning rate is reduced
      lr_tensor: tensorflow (or keras) tensor for the learning rate
      patience: patience of the learning rate scheduler
      tau_tensor: tensor that keeps the annealed temperature
      tau: initial temperature
      min_tem: minimum temperature
      gumble: True if Gumbel-Softmax is used
    """
    super(LearningHandler, self).__init__()
    self.lr = lr
    self.drop = drop
    self.lr_tensor = lr_tensor
    self.patience = patience
    self.tau = tau
    self.tau_tensor = tau_tensor
    self.min_tem = min_tem
    self.gumble = gumble

  def on_train_begin(self, logs=None):
    """Initializes the parameters at the start of training."""
    self.assign_op = tf.no_op()
    self.scheduler_stage = 0
    self.best_loss = np.inf
    self.wait = 0

  def on_epoch_end(self, epoch, logs=None):
    """Manages the learning rate, early stopping, and temperature."""
    stop_training = False
    min_tem = self.min_tem
    anneal_rate = 0.00003
    # anneal the Gumbel-Softmax temperature every 20 epochs, down to min_tem
    if self.gumble and epoch % 20 == 0:
      self.tau = np.maximum(self.tau * np.exp(-anneal_rate * epoch), min_tem)
      K.set_value(self.tau_tensor, self.tau)
    # check if we need to stop or increase the scheduler stage
    if isinstance(logs, dict):
      loss = logs['loss']
    else:
      loss = logs
    if loss <= self.best_loss:
      self.best_loss = loss
      self.wait = 0
    else:
      self.wait += 1
      if self.wait > self.patience:
        self.scheduler_stage += 1
        self.wait = 0
    if math.isnan(loss):
      stop_training = True
    # calculate and set the learning rate
    lr = self.lr * np.power(self.drop, self.scheduler_stage)
    K.set_value(self.lr_tensor, lr)

    # built-in stopping if the learning rate is way too small
    if lr <= 1e-9:
      stop_training = True

    # for keras
    if hasattr(self, 'model') and self.model is not None:
      self.model.stop_training = stop_training

    return stop_training
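

# A minimal usage sketch (hypothetical names; lr_tensor must be the tensor
# the optimizer actually reads its learning rate from):
#
#   handler = LearningHandler(lr=1e-3, drop=0.1, lr_tensor=lr_var,
#                             patience=10)
#   model.fit(x_train, y_train, epochs=100, callbacks=[handler])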


def sample_gumbel(shape, eps=1e-20):
  """Samples from Gumbel(0, 1)."""
  samples = tf.random_uniform(shape, minval=0, maxval=1)
  return -tf.log(-tf.log(samples + eps) + eps)
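

# If U ~ Uniform(0, 1), then -log(-log(U)) follows Gumbel(0, 1) (inverse-CDF
# sampling); eps guards against taking log(0) at either nesting level.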


def gumbel_softmax_sample(logits, temperature):
  """Draws a sample from the Gumbel-Softmax distribution."""
  y = logits + sample_gumbel(tf.shape(logits))
  return tf.nn.softmax(y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
  """Samples from the Gumbel-Softmax distribution and optionally discretizes.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: positive scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot; otherwise it
    will be a probability distribution that sums to 1 across classes.
  """
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    # straight-through estimator: the forward pass uses the one-hot argmax,
    # while gradients flow through the soft sample y
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
    y = tf.stop_gradient(y_hard - y) + y
  return y
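

# As temperature -> 0 the samples approach one-hot vectors; large
# temperatures flatten them toward uniform. A hypothetical usage sketch,
# with `cluster_logits` of shape [batch_size, n_class]:
#
#   tau = tf.placeholder_with_default(1.0, shape=())
#   y_soft = gumbel_softmax(cluster_logits, tau)             # differentiable
#   y_hard = gumbel_softmax(cluster_logits, tau, hard=True)  # one-hot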


def get_scale(x, batch_size, n_nbrs):
  """Calculates the scale.

  The scale is based on the median distance of the kth
  neighbors of each point of x*, an m-sized sample of x, where
  k = n_nbrs and m = batch_size.

  Args:
    x: data for which to compute the scale.
    batch_size: m in the aforementioned calculation.
    n_nbrs: k in the aforementioned calculation.

  Returns:
    the scale, which is the variance term of the gaussian affinity matrix
    used by ncutnet
  """
  n = len(x)

  # sample a random batch of size batch_size
  sample = x[np.random.randint(n, size=batch_size), :]
  # flatten it
  sample = sample.reshape((batch_size, np.prod(sample.shape[1:])))

  # compute the distances of the nearest neighbors
  nbrs = NearestNeighbors(n_neighbors=n_nbrs).fit(sample)
  distances, _ = nbrs.kneighbors(sample)

  # return the median distance to the kth neighbor
  return np.median(distances[:, n_nbrs - 1])
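

# For example, for x of shape [10000, 28, 28], get_scale(x, 1024, 3) samples
# 1024 points, flattens each to a 784-d vector, and returns the median
# distance to the 3rd-nearest neighbor (since kneighbors counts each point
# as its own nearest neighbor, this is the 2nd other point).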


def calculate_cost_matrix(cluster, n_clusters):
  """Builds the Munkres cost matrix from the confusion matrix `cluster`."""
  cost_matrix = np.zeros((n_clusters, n_clusters))

  # cost_matrix[j, i] is the cost of assigning cluster j to label i, i.e. the
  # number of points in cluster j whose true label is not i
  for j in range(n_clusters):
    s = np.sum(cluster[:, j])  # number of points in cluster j
    for i in range(n_clusters):
      t = cluster[i, j]  # of those, the ones with true label i
      cost_matrix[j, i] = s - t
  return cost_matrix


def get_cluster_labels_from_indices(indices):
  """Extracts the label for each cluster from Munkres (cluster, label) pairs."""
  n_clusters = len(indices)
  cluster_labels = np.zeros(n_clusters)
  for i in range(n_clusters):
    cluster_labels[i] = indices[i][1]
  return cluster_labels


def get_accuracy(cluster_assignments, y_true, n_clusters):
  """Computes the accuracy.

  Computes the accuracy of the cluster assignments against the true labels,
  using the Munkres algorithm to find the best cluster-to-label mapping.

  Args:
    cluster_assignments: array of labels, output by kmeans
    y_true: true labels
    n_clusters: number of clusters in the dataset

  Returns:
    a tuple containing the accuracy and confusion matrix, in that order
  """
  y_pred, confusion_matrix = get_y_preds(cluster_assignments, y_true,
                                         n_clusters)
  # calculate the accuracy
  return np.mean(y_pred == y_true), confusion_matrix


def print_accuracy(cluster_assignments,
                   y_true,
                   n_clusters,
                   extra_identifier=''):
  """Prints the confusion matrix and accuracy, and returns the accuracy."""
  # get the accuracy and confusion matrix
  accuracy, confusion_matrix = get_accuracy(cluster_assignments, y_true,
                                            n_clusters)
  # print the confusion matrix and the accuracy
  print('confusion matrix{}: '.format(extra_identifier))
  print(confusion_matrix)
  print(('Cnc_net{} accuracy: '.format(extra_identifier) +
         str(np.round(accuracy, 3))))
  return str(np.round(accuracy, 3))


def get_y_preds(cluster_assignments, y_true, n_clusters):
  """Computes the predicted labels.

  Label assignments now correspond to the actual labels in
  y_true (as estimated by Munkres).

  Args:
    cluster_assignments: array of labels, output by kmeans
    y_true: true labels
    n_clusters: number of clusters in the dataset

  Returns:
    a tuple containing the remapped predicted labels and the confusion
    matrix, in that order
  """
  confusion_matrix = sklearn.metrics.confusion_matrix(
      y_true, cluster_assignments, labels=None)
  # compute the accuracy based on the optimal 1:1 assignment of clusters
  # to labels
  cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters)
  indices = Munkres().compute(cost_matrix)
  true_cluster_labels = get_cluster_labels_from_indices(indices)
  y_pred = true_cluster_labels[cluster_assignments]
  return y_pred, confusion_matrix
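

# A small worked example of the label remapping: with y_true = [0, 0, 1, 1]
# and cluster_assignments = [1, 1, 0, 0], the confusion matrix is
# [[0, 2], [2, 0]] and the cost matrix is [[2, 0], [0, 2]], so Munkres maps
# cluster 0 -> label 1 and cluster 1 -> label 0; y_pred becomes [0, 0, 1, 1]
# and get_accuracy(...) reports an accuracy of 1.0.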