paddlenlp

Форк
0
/
test_perplexity.py 
102 строки · 4.0 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import unittest
15

16
import numpy as np
17
import paddle
18

19
from paddlenlp.metrics import Perplexity
20
from tests.common_test import CommonTest
21
from tests.testing_utils import cross_entropy, stable_softmax
22

23

24
class NpPerplexity(object):
25
    def __init__(self):
26
        self.total_ce = 0
27
        self.total_word_num = 0
28

29
    def compute(self, pred, label, seq_mask=None):
30
        label = np.expand_dims(label, axis=2)
31
        ce = cross_entropy(softmax=pred, label=label, soft_label=False, axis=-1, ignore_index=-100)
32
        ce = np.squeeze(ce, axis=2)
33
        if seq_mask is not None:
34
            ce = ce * seq_mask
35
            word_num = np.sum(seq_mask)
36
            return ce, word_num
37
        return ce
38

39
    def update(self, ce):
40
        self.total_ce += np.sum(ce)
41
        self.total_word_num += ce.size
42

43
    def accumulate(self):
44
        return np.exp(self.total_ce / self.total_word_num)
45

46

47
class TestPerplexity(CommonTest):
48
    def setUp(self):
49
        self.config["name"] = "test_perplexity"
50
        self.cls_num = 10
51
        self.shape = (5, 20, self.cls_num)
52
        self.label_shape = (5, 20)
53
        self.metrics = Perplexity(**self.config)
54
        self.np_metrics = NpPerplexity()
55

56
    def get_random_case(self):
57
        label = np.random.randint(self.cls_num, size=self.label_shape).astype("int64")
58
        logits = np.random.uniform(0.1, 1.0, self.shape).astype(paddle.get_default_dtype())
59
        pred = np.apply_along_axis(stable_softmax, -1, logits)
60
        seq_mask = np.random.randint(2, size=self.label_shape).astype("int64")
61
        return label, logits, pred, seq_mask
62

63
    def test_name(self):
64
        self.check_output_equal(self.metrics.name(), self.config["name"])
65

66
    def test_compute(self):
67
        label, logits, pred, _ = self.get_random_case()
68
        expected_result = self.np_metrics.compute(pred, label)
69
        result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
70
        self.check_output_equal(expected_result, result.numpy())
71

72
    def test_compute_with_mask(self):
73
        label, logits, pred, seq_mask = self.get_random_case()
74
        expected_result = self.np_metrics.compute(pred, label, seq_mask)
75
        result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label), paddle.to_tensor(seq_mask))
76
        self.check_output_equal(expected_result[0], result[0].numpy())
77
        self.check_output_equal(expected_result[1], result[1])
78

79
    def test_reset(self):
80
        label, logits, pred, _ = self.get_random_case()
81
        result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
82
        self.metrics.update(result.numpy())
83
        self.check_output_not_equal(self.metrics.total_ce, 0)
84
        self.check_output_not_equal(self.metrics.total_word_num, 0)
85

86
        self.metrics.reset()
87
        self.check_output_equal(self.metrics.total_ce, 0)
88
        self.check_output_equal(self.metrics.total_word_num, 0)
89

90
    def test_update_accumulate(self):
91
        steps = 10
92
        for i in range(steps):
93
            label, logits, pred, _ = self.get_random_case()
94
            expected_result = self.np_metrics.compute(pred, label)
95
            result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
96
            self.metrics.update(result.numpy())
97
            self.np_metrics.update(expected_result)
98
        self.check_output_equal(self.metrics.accumulate(), self.np_metrics.accumulate())
99

100

101
if __name__ == "__main__":
102
    unittest.main()
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.