google-research

Форк
0
332 строки · 9.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tensorflow utils."""
17

18
from typing import List, Tuple
19
import tensorflow as tf
20

21

22
@tf.function
def get_model_feature(model, batch_x):
  """Returns the features the model extracts from a batch of inputs.

  Args:
    model: an object exposing `get_feature(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The result of `model.get_feature` called with `training=False`.
  """
  return model.get_feature(batch_x, training=False)
30

31

32
@tf.function
def get_model_output(model, batch_x):
  """Returns the model's outputs for a batch of inputs.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The result of calling `model` with `training=False`.
  """
  return model(batch_x, training=False)
40

41

42
@tf.function
def get_model_output_and_feature(model, batch_x):
  """Returns both the model's outputs and its features for a batch of inputs.

  Args:
    model: an object exposing `get_output_and_feature(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    An `(outputs, features)` pair as produced by
    `model.get_output_and_feature` called with `training=False`.
  """
  # Unpacking (rather than returning the call result directly) keeps the
  # requirement that the model method yields exactly two values.
  model_outputs, model_features = model.get_output_and_feature(
      batch_x, training=False
  )
  return model_outputs, model_features
50

51

52
@tf.function
def get_model_prediction(model, batch_x):
  """Returns the model's predictions for a batch of inputs.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    For each row of the model's outputs, the index of the largest entry
    along axis 1.
  """
  return tf.argmax(model(batch_x, training=False), axis=1)
61

62

63
@tf.function
def get_model_confidence(model, batch_x):
  """Returns the model's confidences for a batch of inputs.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    For each row of the model's outputs, the largest entry along axis 1.
  """
  return tf.math.reduce_max(model(batch_x, training=False), axis=1)
72

73

74
@tf.function
def get_model_margin(model, batch_x):
  """Returns the model's margins for a batch of inputs.

  Args:
    model: a callable model accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    For each row of the model's outputs, the largest entry minus the
    second-largest entry along axis 1.
  """
  ranked_outputs = tf.sort(
      model(batch_x, training=False), axis=1, direction='DESCENDING'
  )
  top_score = ranked_outputs[:, 0]
  runner_up_score = ranked_outputs[:, 1]
  return top_score - runner_up_score
84

85

86
@tf.function
def get_ensemble_model_output(models, batch_x, ensemble_method):
  """Returns the ensemble's averaged outputs for a batch of inputs.

  Args:
    models: a list of callable models accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.
    ensemble_method: 'soft' to average the members' raw outputs, or 'hard'
      to average one-hot encodings of each member's argmax along axis 1.

  Returns:
    The sum of the per-member contributions divided by `len(models)`.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft' (checked
      while processing a member, so only when `models` is non-empty).
  """
  vote_total = 0
  # One-hot depth for hard voting, read from the first member's static
  # output shape.
  depth = None
  for member in models:
    member_output = member(batch_x, training=False)
    if ensemble_method == 'soft':
      vote_total += member_output
    elif ensemble_method == 'hard':
      if depth is None:
        depth = member_output.shape[1]
      vote_total += tf.one_hot(tf.argmax(member_output, axis=1), depth)
    else:
      raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  return vote_total / len(models)
109

110

111
@tf.function
def get_ensemble_model_feature(models, batch_x):
  """Returns the ensemble's features for a batch of inputs.

  Args:
    models: a list of objects exposing `get_feature(inputs, training=...)`.
    batch_x: a batch of inputs.

  Returns:
    The members' features, each computed with `training=False`, joined
    along axis 1 in the order of `models`.
  """
  member_features = [
      member.get_feature(batch_x, training=False) for member in models
  ]
  return tf.concat(member_features, axis=1)
124

125

126
@tf.function
def get_ensemble_model_output_and_feature(
    models, batch_x, ensemble_method, temperature=1.0
):
  """Returns the ensemble's averaged outputs and concatenated features.

  Args:
    models: a list of objects exposing
      `get_output_and_feature(inputs, training=..., temperature=...)`.
    batch_x: a batch of inputs.
    ensemble_method: 'soft' to average the members' raw outputs, or 'hard'
      to average one-hot encodings of each member's argmax along axis 1.
    temperature: forwarded unchanged to every member's
      `get_output_and_feature` call.

  Returns:
    A pair `(outputs, features)`: the summed per-member contributions
    divided by `len(models)`, and the members' features concatenated along
    axis 1 in the order of `models`.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft' (checked
      while processing a member, so only when `models` is non-empty).
  """
  vote_total = 0
  # One-hot depth for hard voting, read from the first member's static
  # output shape.
  depth = None
  member_features = []
  for member in models:
    member_output, member_feature = member.get_output_and_feature(
        batch_x, training=False, temperature=temperature
    )
    member_features.append(member_feature)
    if ensemble_method == 'soft':
      vote_total += member_output
    elif ensemble_method == 'hard':
      if depth is None:
        depth = member_output.shape[1]
      vote_total += tf.one_hot(tf.argmax(member_output, axis=1), depth)
    else:
      raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  ensemble_feature = tf.concat(member_features, axis=1)
  return vote_total / len(models), ensemble_feature
156

157

158
@tf.function
def get_ensemble_model_prediction(models, batch_x, ensemble_method):
  """Returns the ensemble's predictions for a batch of inputs.

  Args:
    models: a list of callable models accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.
    ensemble_method: 'soft' to average the members' raw outputs, or 'hard'
      to average one-hot encodings of each member's argmax along axis 1.

  Returns:
    For each row of the averaged ensemble output, the index of the largest
    entry along axis 1.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft' (checked
      while processing a member, so only when `models` is non-empty).
  """
  vote_total = 0
  # One-hot depth for hard voting, read from the first member's static
  # output shape.
  depth = None
  for member in models:
    member_output = member(batch_x, training=False)
    if ensemble_method == 'soft':
      vote_total += member_output
    elif ensemble_method == 'hard':
      if depth is None:
        depth = member_output.shape[1]
      vote_total += tf.one_hot(tf.argmax(member_output, axis=1), depth)
    else:
      raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  averaged_output = vote_total / len(models)
  return tf.argmax(averaged_output, axis=1)
191

192

193
@tf.function
def get_ensemble_model_confidence(models, batch_x, ensemble_method):
  """Returns the ensemble's confidences for a batch of inputs.

  Args:
    models: a list of callable models accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.
    ensemble_method: 'soft' to average the members' raw outputs, or 'hard'
      to average one-hot encodings of each member's argmax along axis 1.

  Returns:
    For each row of the averaged ensemble output, the largest entry along
    axis 1.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft' (checked
      while processing a member, so only when `models` is non-empty).
  """
  vote_total = 0
  # One-hot depth for hard voting, read from the first member's static
  # output shape.
  depth = None
  for member in models:
    member_output = member(batch_x, training=False)
    if ensemble_method == 'soft':
      vote_total += member_output
    elif ensemble_method == 'hard':
      if depth is None:
        depth = member_output.shape[1]
      vote_total += tf.one_hot(tf.argmax(member_output, axis=1), depth)
    else:
      raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  averaged_output = vote_total / len(models)
  return tf.math.reduce_max(averaged_output, axis=1)
226

227

228
@tf.function
def get_ensemble_model_margin(models, batch_x, ensemble_method):
  """Returns the ensemble's margins for a batch of inputs.

  Args:
    models: a list of callable models accepting `(inputs, training=...)`.
    batch_x: a batch of inputs.
    ensemble_method: 'soft' to average the members' raw outputs, or 'hard'
      to average one-hot encodings of each member's argmax along axis 1.

  Returns:
    For each row of the averaged ensemble output, the largest entry minus
    the second-largest entry along axis 1.

  Raises:
    ValueError: if `ensemble_method` is neither 'hard' nor 'soft' (checked
      while processing a member, so only when `models` is non-empty).
  """
  vote_total = 0
  # One-hot depth for hard voting, read from the first member's static
  # output shape.
  depth = None
  for member in models:
    member_output = member(batch_x, training=False)
    if ensemble_method == 'soft':
      vote_total += member_output
    elif ensemble_method == 'hard':
      if depth is None:
        depth = member_output.shape[1]
      vote_total += tf.one_hot(tf.argmax(member_output, axis=1), depth)
    else:
      raise ValueError(f'Not supported ensemble method: {ensemble_method}!')
  averaged_output = vote_total / len(models)
  ranked_output = tf.sort(averaged_output, axis=1, direction='DESCENDING')
  top_score = ranked_output[:, 0]
  runner_up_score = ranked_output[:, 1]
  return top_score - runner_up_score
267

268

269
def evaluate_acc(model, ds):
  """Evaluates model's accuracy on the dataset.

  Args:
    model: the model to evaluate; forwarded to `get_model_prediction`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs.

  Returns:
    The number of predictions equal to their label divided by the number of
    labels seen (the sum of `batch_y.shape[0]` over all batches).

  Raises:
    ZeroDivisionError: if `ds` yields no batches.
  """
  n = 0
  correct = 0
  for batch_x, batch_y in ds:
    batch_pred = get_model_prediction(model, batch_x)
    # `tf.argmax` returns int64 indices by default; comparing them with
    # labels of another dtype (e.g. int32) raises a dtype-mismatch error, so
    # bring the labels to the prediction dtype first. This is a no-op when
    # the dtypes already agree.
    batch_hits = batch_pred == tf.cast(batch_y, batch_pred.dtype)
    correct += tf.math.reduce_sum(tf.cast(batch_hits, dtype=tf.int32))
    n += batch_y.shape[0]
  return correct / n
283

284

285
def evaluate_ensemble_acc(models, ds, ensemble_method='soft'):
  """Evaluates ensemble's accuracy on the dataset.

  Args:
    models: a list of models; forwarded to `get_ensemble_model_prediction`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs.
    ensemble_method: the method to construct the ensemble, forwarded to
      `get_ensemble_model_prediction`. Defaults to 'soft', the value this
      function previously hard-coded.

  Returns:
    The number of ensemble predictions equal to their label divided by the
    number of labels seen (the sum of `batch_y.shape[0]` over all batches).

  Raises:
    ZeroDivisionError: if `ds` yields no batches.
  """
  n = 0
  correct = 0
  for batch_x, batch_y in ds:
    batch_pred = get_ensemble_model_prediction(
        models,
        batch_x,
        ensemble_method=ensemble_method,
    )
    # `tf.argmax` returns int64 indices by default; comparing them with
    # labels of another dtype (e.g. int32) raises a dtype-mismatch error, so
    # bring the labels to the prediction dtype first. This is a no-op when
    # the dtypes already agree.
    batch_hits = batch_pred == tf.cast(batch_y, batch_pred.dtype)
    correct += tf.math.reduce_sum(tf.cast(batch_hits, dtype=tf.int32))
    n += batch_y.shape[0]
  return correct / n
303

304

305
def evaluate_loss(model, ds, loss_func_name='CE'):
  """Evaluates model's average loss on the dataset.

  Args:
    model: the model to evaluate; forwarded to `get_model_output`.
    ds: an iterable yielding `(batch_x, batch_y)` pairs.
    loss_func_name: name of the loss to use. Only 'CE' (sparse categorical
      cross-entropy) is supported.

  Returns:
    The per-batch summed losses added together and divided by the number of
    labels seen (the sum of `batch_y.shape[0]` over all batches).

  Raises:
    ValueError: if `loss_func_name` is not 'CE'.
    ZeroDivisionError: if `ds` yields no batches.
  """
  if loss_func_name != 'CE':
    raise ValueError(f'Not supported loss function {loss_func_name}!')
  # The string 'sum' is the value of `tf.keras.losses.Reduction.SUM` in
  # Keras 2 and is also accepted by Keras 3, where the `Reduction` enum no
  # longer exists; using it keeps this working across both.
  loss_func = tf.keras.losses.SparseCategoricalCrossentropy(reduction='sum')
  loss = 0
  n = 0
  for batch_x, batch_y in ds:
    batch_output = get_model_output(model, batch_x)
    # Each batch contributes its summed loss; the mean is taken once at the
    # end so batches of different sizes are weighted correctly.
    loss += loss_func(batch_y, batch_output)
    n += batch_y.shape[0]
  return loss / n
324

325

326
def entropy_loss(outputs, epsilon=1e-6):
  """Computes the entropy of each row of `outputs`.

  Args:
    outputs: a tensor whose entries are summed over axis 1.
    epsilon: a constant added inside the logarithm so that zero entries do
      not produce `log(0)`.

  Returns:
    `-sum(outputs * log(outputs + epsilon))` taken along axis 1.
  """
  log_outputs = tf.math.log(outputs + epsilon)
  return -tf.reduce_sum(outputs * log_outputs, axis=1)
333

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.