google-research
78 строк · 2.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Library of functions for model calibration."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import scipy.optimize23import scipy.special24import tensorflow.compat.v2 as tf25import tensorflow_probability as tfp26
27from uq_benchmark_2019 import uq_utils28
29
def find_scaling_temperature(labels, logits, temp_range=(1e-5, 1e5)):
  """Find max likelihood scaling temperature using binary search.

  The temperature is located as the root of the derivative of the summed
  negative log-likelihood with respect to the temperature, using
  `scipy.optimize.bisect` over `temp_range`.

  Args:
    labels: Integer labels (shape=[num_samples]).
    logits: Floating point softmax inputs (shape=[num_samples, num_classes]),
      or binary logits (shape=[num_samples]) which are modelled with a
      Bernoulli distribution.
    temp_range: 2-tuple range of temperatures to consider.
  Returns:
    Floating point temperature value.
  Raises:
    NotImplementedError: If called while TF is not executing eagerly.
    ValueError: If `labels`/`logits` have invalid or incompatible shapes, or
      (raised by `scipy.optimize.bisect`) if the gradient does not change
      sign between the endpoints of `temp_range`.
  """
  if not tf.executing_eagerly():
    raise NotImplementedError(
        'find_scaling_temperature() not implemented for graph-mode TF')
  if len(labels.shape) != 1:
    raise ValueError('Invalid labels shape=%s' % str(labels.shape))
  if len(logits.shape) not in (1, 2):
    raise ValueError('Invalid logits shape=%s' % str(logits.shape))
  if len(labels) != len(logits):
    raise ValueError('Incompatible shapes for logits (%s) vs labels (%s).' %
                     (logits.shape, labels.shape))

  # Convert once up front so the arrays are captured by the traced function
  # as tensors rather than being re-embedded as graph constants per trace.
  logits = tf.convert_to_tensor(logits)
  if not logits.dtype.is_floating:
    logits = tf.cast(logits, tf.float32)
  labels = tf.convert_to_tensor(labels)
  is_binary = len(logits.shape) == 1

  @tf.function(autograph=False)
  def grad_fn(temperature):
    """Returns gradient of log-likelihood WRT a logits-scaling temperature."""
    scaled_logits = logits / temperature
    if is_binary:
      dist = tfp.distributions.Bernoulli(logits=scaled_logits)
    else:
      dist = tfp.distributions.Categorical(logits=scaled_logits)
    nll = tf.reduce_sum(-dist.log_prob(labels), axis=0)
    grad, = tf.gradients(nll, [temperature])
    return grad

  def objective(t):
    # Passing a tensor (not a Python float) keeps tf.function from retracing
    # on every bisection step, and matching logits' dtype avoids a
    # float32/float64 mismatch in `logits / temperature`.
    return grad_fn(tf.constant(t, dtype=logits.dtype)).numpy()

  tmin, tmax = temp_range
  return scipy.optimize.bisect(objective, tmin, tmax)
67
def apply_temperature_scaling(temperature, probs):
  """Rescale probabilities by a softmax temperature.

  The probabilities are mapped back to logits with
  `uq_utils.np_inverse_softmax`, divided by `temperature`, and renormalized
  with a softmax over the last axis.

  Args:
    temperature: Floating point temperature.
    probs: Array of probabilities with probabilities over axis=-1.
  Returns:
    Temperature-scaled probabilities; same shape as input probs.
  """
  inverse = uq_utils.np_inverse_softmax(probs)
  # The division happens on the transposed array, so an array-valued
  # temperature (if one is passed) lines up with the leading axis of `probs`;
  # transposing back restores the original layout before the softmax.
  scaled = (inverse.T / temperature).T
  return scipy.special.softmax(scaled, axis=-1)