google-research
332 строки · 9.0 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""TensorFlow utils."""
17
18from typing import List, Tuple19import tensorflow as tf20
21
@tf.function
def get_model_feature(model, batch_x):
  """Extracts the model's features for a batch of inputs.

  Args:
    model: an object exposing `get_feature(inputs, training=...)`.
    batch_x: a batch of inputs forwarded to `model.get_feature`.

  Returns:
    The result of `model.get_feature` called with `training=False`.
  """
  return model.get_feature(batch_x, training=False)
31
@tf.function
def get_model_output(model, batch_x):
  """Runs the model on a batch of inputs in inference mode.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The result of `model(batch_x, training=False)`.
  """
  return model(batch_x, training=False)
41
@tf.function
def get_model_output_and_feature(model, batch_x):
  """Computes the model's outputs and features for a batch of inputs.

  Args:
    model: an object exposing
      `get_output_and_feature(inputs, training=...)`, which returns an
      `(outputs, features)` pair.
    batch_x: a batch of inputs.

  Returns:
    The `(outputs, features)` pair produced with `training=False`.
  """
  result = model.get_output_and_feature(batch_x, training=False)
  outputs, features = result
  return outputs, features
51
@tf.function
def get_model_prediction(model, batch_x):
  """Predicts a class index for every input in the batch.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The argmax of the model's outputs along axis 1.
  """
  return tf.math.argmax(model(batch_x, training=False), axis=1)
62
@tf.function
def get_model_confidence(model, batch_x):
  """Computes the model's confidence for every input in the batch.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The maximum of the model's outputs along axis 1.
  """
  return tf.reduce_max(model(batch_x, training=False), axis=1)
73
@tf.function
def get_model_margin(model, batch_x):
  """Computes the model's margin for every input in the batch.

  The margin is the gap between the largest and the second largest output
  along axis 1.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The largest output minus the second largest output, per input.
  """
  ranked = tf.sort(
      model(batch_x, training=False), axis=1, direction='DESCENDING'
  )
  best = ranked[:, 0]
  runner_up = ranked[:, 1]
  return best - runner_up
85
@tf.function
def get_ensemble_model_output(models, batch_x, ensemble_method):
  """Gets ensemble model's outputs on the given inputs.

  Args:
    models: a non-empty list of callable models accepting
      `(inputs, training=...)`.
    batch_x: a batch of inputs.
    ensemble_method: the method to construct the ensemble. 'soft' averages
      the members' outputs; 'hard' averages one-hot encodings of each
      member's argmax prediction (i.e. per-class vote fractions).

  Returns:
    The members' (possibly one-hot encoded) outputs averaged over the models.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft', or if
      `models` is empty.
  """
  # Validate once up front rather than once per model inside the loop, so a
  # bad method is reported even before any model is evaluated.
  if ensemble_method not in ('hard', 'soft'):
    raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  if not models:
    # Without this check the final average divides by zero.
    raise ValueError('At least one model is required to build an ensemble.')

  batch_ensemble_output = 0
  num_classes = None
  for model in models:
    batch_output = model(batch_x, training=False)
    if ensemble_method == 'hard':
      batch_pred = tf.argmax(batch_output, axis=1)
      if num_classes is None:
        # The one-hot depth is taken from the first member's output width.
        num_classes = batch_output.shape[1]
      batch_output = tf.one_hot(batch_pred, num_classes)
    batch_ensemble_output += batch_output
  return batch_ensemble_output / len(models)
110
@tf.function
def get_ensemble_model_feature(models, batch_x):
  """Computes the ensemble's features for a batch of inputs.

  Args:
    models: a list of objects exposing `get_feature(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The members' features (computed with `training=False`) concatenated
    along axis 1, in the order of `models`.
  """
  member_features = [
      member.get_feature(batch_x, training=False) for member in models
  ]
  # The ensemble feature is the members' features placed side by side.
  return tf.concat(member_features, axis=1)
125
@tf.function
def get_ensemble_model_output_and_feature(
    models,
    batch_x,
    ensemble_method,
    temperature = 1.0,
):
  """Gets ensemble model's outputs and features on the given inputs.

  Args:
    models: a non-empty list of objects exposing
      `get_output_and_feature(inputs, training=..., temperature=...)`.
    batch_x: a batch of inputs.
    ensemble_method: the method to construct the ensemble. 'soft' averages
      the members' outputs; 'hard' averages one-hot encodings of each
      member's argmax prediction.
    temperature: forwarded unchanged to every member's
      `get_output_and_feature`.

  Returns:
    A pair `(outputs, features)`: the members' (possibly one-hot encoded)
    outputs averaged over the models, and the members' features concatenated
    along axis 1.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft', or if
      `models` is empty.
  """
  # Validate once up front rather than once per model inside the loop.
  if ensemble_method not in ('hard', 'soft'):
    raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  if not models:
    # Without this check there is nothing to concatenate or average.
    raise ValueError('At least one model is required to build an ensemble.')

  batch_ensemble_output = 0
  batch_feature_list = []
  num_classes = None
  for model in models:
    batch_output, batch_feature = model.get_output_and_feature(
        batch_x, training=False, temperature=temperature,
    )
    batch_feature_list.append(batch_feature)
    if ensemble_method == 'hard':
      batch_pred = tf.argmax(batch_output, axis=1)
      if num_classes is None:
        # The one-hot depth is taken from the first member's output width.
        num_classes = batch_output.shape[1]
      batch_output = tf.one_hot(batch_pred, num_classes)
    batch_ensemble_output += batch_output
  # Concatenates the features of the models in the ensemble.
  concat_batch_feature = tf.concat(batch_feature_list, axis=1)
  return batch_ensemble_output / len(models), concat_batch_feature
157
@tf.function
def get_ensemble_model_prediction(
    models,
    batch_x,
    ensemble_method,
):
  """Gets ensemble model's predictions on the given inputs.

  Args:
    models: a list of models
    batch_x: a batch of inputs
    ensemble_method: the method to construct ensemble ('hard' or 'soft')

  Returns:
    The ensemble model's predictions, i.e. the argmax along axis 1 of the
    averaged ensemble outputs.

  Raises:
    ValueError: propagated from `get_ensemble_model_output` when
      `ensemble_method` is not supported.
  """
  # Reuse the shared aggregation instead of duplicating the voting loop.
  batch_ensemble_output = get_ensemble_model_output(
      models, batch_x, ensemble_method
  )
  return tf.argmax(batch_ensemble_output, axis=1)
192
@tf.function
def get_ensemble_model_confidence(models, batch_x, ensemble_method):
  """Gets ensemble model's confidences on the given inputs.

  Args:
    models: a list of models
    batch_x: a batch of inputs
    ensemble_method: the method to construct ensemble ('hard' or 'soft')

  Returns:
    The ensemble model's confidences, i.e. the maximum along axis 1 of the
    averaged ensemble outputs.

  Raises:
    ValueError: propagated from `get_ensemble_model_output` when
      `ensemble_method` is not supported.
  """
  # Reuse the shared aggregation instead of duplicating the voting loop.
  batch_ensemble_output = get_ensemble_model_output(
      models, batch_x, ensemble_method
  )
  return tf.math.reduce_max(batch_ensemble_output, axis=1)
227
@tf.function
def get_ensemble_model_margin(models, batch_x, ensemble_method):
  """Gets ensemble model's margins on the given inputs.

  Args:
    models: a list of models
    batch_x: a batch of inputs
    ensemble_method: the method to construct ensemble ('hard' or 'soft')

  Returns:
    The ensemble model's margins, i.e. the largest minus the second largest
    averaged ensemble output along axis 1.

  Raises:
    ValueError: propagated from `get_ensemble_model_output` when
      `ensemble_method` is not supported.
  """
  # Reuse the shared aggregation instead of duplicating the voting loop.
  batch_ensemble_output = get_ensemble_model_output(
      models, batch_x, ensemble_method
  )
  batch_sorted_ensemble_outputs = tf.sort(
      batch_ensemble_output, direction='DESCENDING', axis=1
  )
  batch_margins = (
      batch_sorted_ensemble_outputs[:, 0] - batch_sorted_ensemble_outputs[:, 1]
  )
  return batch_margins
268
def evaluate_acc(model, ds):
  """Evaluates model's accuracy on the dataset.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs, where `batch_y`
      holds the class index of every input (assumed to have the same shape
      as the predictions -- TODO confirm against callers).

  Returns:
    The number of correct predictions divided by the number of examples.
    Dividing by zero raises `ZeroDivisionError` when `ds` yields no examples.
  """
  n = 0
  correct = 0
  for batch_x, batch_y in ds:
    batch_pred = get_model_prediction(model, batch_x)
    # Predictions come from `tf.argmax` (int64 by default). Cast the labels to
    # the same dtype so labels stored as e.g. int32 or uint8 compare cleanly
    # instead of failing on a dtype mismatch; int64 labels are unaffected.
    batch_label = tf.cast(batch_y, batch_pred.dtype)
    correct += tf.math.reduce_sum(
        tf.cast(batch_pred == batch_label, dtype=tf.int32)
    )
    n += batch_y.shape[0]
  return correct / n
284
def evaluate_ensemble_acc(models, ds, ensemble_method = 'soft'):
  """Evaluates ensemble's accuracy on the dataset.

  Args:
    models: a list of callable models accepting `(inputs, training=...)`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs, where `batch_y`
      holds the class index of every input (assumed to have the same shape
      as the predictions -- TODO confirm against callers).
    ensemble_method: the method to construct the ensemble, forwarded to
      `get_ensemble_model_prediction`. Defaults to 'soft', the value that
      was previously hard-coded.

  Returns:
    The number of correct predictions divided by the number of examples.
    Dividing by zero raises `ZeroDivisionError` when `ds` yields no examples.
  """
  n = 0
  correct = 0
  for batch_x, batch_y in ds:
    batch_pred = get_ensemble_model_prediction(
        models,
        batch_x,
        ensemble_method=ensemble_method,
    )
    # Predictions come from `tf.argmax` (int64 by default). Cast the labels to
    # the same dtype so labels stored as e.g. int32 or uint8 compare cleanly
    # instead of failing on a dtype mismatch; int64 labels are unaffected.
    batch_label = tf.cast(batch_y, batch_pred.dtype)
    correct += tf.math.reduce_sum(
        tf.cast(batch_pred == batch_label, dtype=tf.int32)
    )
    n += batch_y.shape[0]
  return correct / n
304
def evaluate_loss(model, ds, loss_func_name = 'CE'):
  """Evaluates model's average loss on the dataset.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs with integer class
      labels in `batch_y`.
    loss_func_name: name of the loss to evaluate. Only 'CE' (sparse
      categorical cross-entropy) is supported.

  Returns:
    The loss summed over all examples divided by the number of examples.
    Dividing by zero raises `ZeroDivisionError` when `ds` yields no examples.

  Raises:
    ValueError: if `loss_func_name` is not supported.
  """
  if loss_func_name != 'CE':
    raise ValueError(f'Not supported loss function {loss_func_name}!')
  # 'sum' is the string value behind `tf.keras.losses.Reduction.SUM`; spelling
  # it as a string avoids depending on the `Reduction` enum being exported.
  # NOTE(review): the loss is built with Keras defaults otherwise, so the
  # model outputs are presumably probabilities rather than logits -- confirm.
  loss_func = tf.keras.losses.SparseCategoricalCrossentropy(reduction='sum')

  loss = 0
  n = 0
  for batch_x, batch_y in ds:
    batch_output = get_model_output(model, batch_x)
    # Per-batch sums are accumulated so the final division yields a
    # per-example mean regardless of batch sizes.
    loss += loss_func(batch_y, batch_output)
    n += batch_y.shape[0]
  return loss / n
325
def entropy_loss(outputs, epsilon = 1e-6):
  """Computes the entropy of each row of `outputs`.

  Args:
    outputs: the values whose entropy is computed; reduced along axis 1.
    epsilon: a small constant added inside the logarithm to avoid `log(0)`.

  Returns:
    `-sum(outputs * log(outputs + epsilon))` taken along axis 1.
  """
  log_term = tf.math.log(outputs + epsilon)
  weighted = outputs * log_term
  return -tf.reduce_sum(weighted, axis=1)