google-research

Форк
0
/
calibration_lib.py 
78 строк · 2.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Library of functions for model calibration."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import scipy.optimize
23
import scipy.special
24
import tensorflow.compat.v2 as tf
25
import tensorflow_probability as tfp
26

27
from uq_benchmark_2019 import uq_utils
28

29

30
def find_scaling_temperature(labels, logits, temp_range=(1e-5, 1e5)):
31
  """Find max likelihood scaling temperature using binary search.
32

33
  Args:
34
    labels: Integer labels (shape=[num_samples]).
35
    logits: Floating point softmax inputs (shape=[num_samples, num_classes]).
36
    temp_range: 2-tuple range of temperatures to consider.
37
  Returns:
38
    Floating point temperature value.
39
  """
40
  if not tf.executing_eagerly():
41
    raise NotImplementedError(
42
        'find_scaling_temperature() not implemented for graph-mode TF')
43
  if len(labels.shape) != 1:
44
    raise ValueError('Invalid labels shape=%s' % str(labels.shape))
45
  if len(logits.shape) not in (1, 2):
46
    raise ValueError('Invalid logits shape=%s' % str(logits.shape))
47
  if len(labels.shape) != 1 or len(labels) != len(logits):
48
    raise ValueError('Incompatible shapes for logits (%s) vs labels (%s).' %
49
                     (logits.shape, labels.shape))
50

51
  @tf.function(autograph=False)
52
  def grad_fn(temperature):
53
    """Returns gradient of log-likelihood WRT a logits-scaling temperature."""
54
    temperature *= tf.ones([])
55
    if len(logits.shape) == 1:
56
      dist = tfp.distributions.Bernoulli(logits=logits / temperature)
57
    elif len(logits.shape) == 2:
58
      dist = tfp.distributions.Categorical(logits=logits / temperature)
59
    nll = -dist.log_prob(labels)
60
    nll = tf.reduce_sum(nll, axis=0)
61
    grad, = tf.gradients(nll, [temperature])
62
    return grad
63

64
  tmin, tmax = temp_range
65
  return scipy.optimize.bisect(lambda t: grad_fn(t).numpy(), tmin, tmax)
66

67

68
def apply_temperature_scaling(temperature, probs):
69
  """Apply temperature scaling to an array of probabilities.
70

71
  Args:
72
    temperature: Floating point temperature.
73
    probs: Array of probabilities with probabilities over axis=-1.
74
  Returns:
75
    Temperature-scaled probabilities; same shape as input probs.
76
  """
77
  logits_t = uq_utils.np_inverse_softmax(probs).T / temperature
78
  return scipy.special.softmax(logits_t.T, axis=-1)
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.