google-research

Форк
0
54 строки · 1.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for uq_utils."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import numpy as np
23
import scipy.special
24
import tensorflow.compat.v2 as tf
25

26
from uq_benchmark_2019 import uq_utils
27

28

29
class UqUtilsTest(tf.test.TestCase):
30

31
  def test_np_inverse_softmax(self):
32
    batch_size, nclasses = [4, 3]
33
    logits_orig = np.random.rand(batch_size, nclasses)
34
    probs_orig = scipy.special.softmax(logits_orig, axis=-1)
35
    logits_new = uq_utils.np_inverse_softmax(probs_orig)
36
    probs_new = scipy.special.softmax(logits_new, axis=-1)
37
    self.assertAllClose(probs_orig, probs_new)
38

39
  def test_np_soften_probabilities(self):
40
    shape = [12, 5]
41
    logits = np.random.uniform(0, 1, size=shape)
42
    probs = scipy.special.softmax(logits, axis=-1)
43
    probs[0] = 0
44
    probs[0, 0] = 1
45
    soft_probs = uq_utils.np_soften_probabilities(probs, epsilon=1e-8)
46
    self.assertAllClose(probs[1:], soft_probs[1:])
47
    self.assertAllLess(soft_probs, 1)
48
    self.assertAllGreater(soft_probs, 0)
49
    self.assertAllClose(np.ones(shape[0]), soft_probs.sum(1), atol=1e-10)
50

51

52
if __name__ == '__main__':
53
  tf.enable_v2_behavior()
54
  tf.test.main()
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.