# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19
from paddlenlp.metrics import Perplexity
20
from tests.common_test import CommonTest
21
from tests.testing_utils import cross_entropy, stable_softmax
# NumPy reference implementation of perplexity, used below to check
# paddlenlp.metrics.Perplexity (TestPerplexity creates one as self.np_metrics
# and calls its compute, update and accumulate methods).
#
# NOTE(review): this file is a lossy extraction. Every code line is preceded by
# its original line number on a line of its own, all indentation has been lost,
# and lines are missing: within this class the numbering skips 25-26, 28, 34,
# 36-39 and 42-43. Judging only from the surviving statements, the gaps
# presumably held the `def __init__` header plus the `self.total_ce`
# initialisation, the rest of the `seq_mask` branch of `compute`, and the `def`
# headers of `update` and `accumulate`. TODO: restore this class from version
# control rather than reconstructing it by hand.
24
class NpPerplexity(object):
# Running token counter starts at zero; the enclosing `def` line and the
# matching `self.total_ce` initialisation are among the missing lines
# (TODO confirm).
27
self.total_word_num = 0
# compute: per-token cross entropy of `pred` against integer `label`.
#   pred     -- probabilities, not logits: TestPerplexity passes the output of
#               stable_softmax here, and cross_entropy receives it as `softmax=`.
#   label    -- integer class ids; assumed to have shape (5, 20) like the
#               test's label_shape -- TODO confirm for other callers.
#   seq_mask -- optional mask with label's shape (random 0/1 in the test).
29
def compute(self, pred, label, seq_mask=None):
# Add a trailing axis so label lines up with pred's last (class) axis.
30
label = np.expand_dims(label, axis=2)
# Hard-label cross entropy over the last axis; entries equal to -100 are
# presumably skipped by the helper (its semantics live in tests.testing_utils).
31
ce = cross_entropy(softmax=pred, label=label, soft_label=False, axis=-1, ignore_index=-100)
# Drop the axis added above so ce has label's original shape (assumes
# cross_entropy keeps that axis with size 1 -- TODO confirm).
32
ce = np.squeeze(ce, axis=2)
33
if seq_mask is not None:
# Number of tokens kept by the mask. The rest of this branch (lines 36-39)
# is missing; test_compute_with_mask indexes the result with [0] and [1], so
# the masked path presumably returns a pair that includes word_num.
35
word_num = np.sum(seq_mask)
# --- update (header missing): add one batch of ce values to the running
# totals. total_word_num grows by ce.size, i.e. every element of ce is
# counted, whether or not it was masked.
40
self.total_ce += np.sum(ce)
41
self.total_word_num += ce.size
# --- accumulate (header missing): perplexity = exp(mean cross entropy over
# all accumulated tokens). Divides by zero if called before any update,
# because total_word_num starts at 0.
44
return np.exp(self.total_ce / self.total_word_num)
# Unit tests comparing paddlenlp.metrics.Perplexity with the NumPy reference
# NpPerplexity on random batches.
#
# NOTE(review): same lossy extraction as the rest of the file (line numbers
# fused in as separate lines, indentation lost, lines dropped). Within this
# class the gaps that matter are:
#   - line 48: presumably the `def setUp` header;
#   - line 50: self.cls_num is read below but never assigned in the visible lines;
#   - line 63: the `def` header of the name check;
#   - line 79: the `def` header of the update/reset test;
#   - lines 85-86: see the note inside that test;
#   - line 91: `steps` is used but never assigned.
# Restore from version control.
47
class TestPerplexity(CommonTest):
# --- setUp (header missing): build the metric under test and the reference.
# self.config is not created in the visible lines -- presumably by CommonTest.
# Logits/probabilities have shape (5, 20, cls_num); labels and masks (5, 20).
49
self.config["name"] = "test_perplexity"
51
self.shape = (5, 20, self.cls_num)
52
self.label_shape = (5, 20)
53
self.metrics = Perplexity(**self.config)
54
self.np_metrics = NpPerplexity()
# get_random_case: build one random batch and return it as a 4-tuple.
#   label    -- int64 class ids drawn from [0, cls_num)
#   logits   -- uniform in [0.1, 1.0), cast to paddle's default dtype
#   pred     -- stable_softmax applied along the last axis of logits; this is
#               what NpPerplexity.compute receives, while Perplexity.compute
#               is given the raw logits
#   seq_mask -- random int64 0/1 values with label's shape
56
def get_random_case(self):
57
label = np.random.randint(self.cls_num, size=self.label_shape).astype("int64")
58
logits = np.random.uniform(0.1, 1.0, self.shape).astype(paddle.get_default_dtype())
59
pred = np.apply_along_axis(stable_softmax, -1, logits)
60
seq_mask = np.random.randint(2, size=self.label_shape).astype("int64")
61
return label, logits, pred, seq_mask
# --- name check (header missing): the metric reports the name from config.
64
self.check_output_equal(self.metrics.name(), self.config["name"])
# Unmasked compute: Perplexity on logits must match NpPerplexity on the
# corresponding probabilities (comparison tolerance is whatever
# CommonTest.check_output_equal uses -- not visible here).
66
def test_compute(self):
67
label, logits, pred, _ = self.get_random_case()
68
expected_result = self.np_metrics.compute(pred, label)
69
result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
70
self.check_output_equal(expected_result, result.numpy())
# Masked compute: both implementations return an indexable pair. Item [0]
# of the Paddle result is converted with .numpy() before comparison; item
# [1] is compared as returned.
72
def test_compute_with_mask(self):
73
label, logits, pred, seq_mask = self.get_random_case()
74
expected_result = self.np_metrics.compute(pred, label, seq_mask)
75
result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label), paddle.to_tensor(seq_mask))
76
self.check_output_equal(expected_result[0], result[0].numpy())
77
self.check_output_equal(expected_result[1], result[1])
# --- update/reset test (header missing): a single update must make both
# running totals of the Paddle metric nonzero.
80
label, logits, pred, _ = self.get_random_case()
81
result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
82
self.metrics.update(result.numpy())
83
self.check_output_not_equal(self.metrics.total_ce, 0)
84
self.check_output_not_equal(self.metrics.total_word_num, 0)
# Lines 85-86 are missing. Since the checks below expect both totals to be
# 0 again, they presumably call self.metrics.reset() -- TODO confirm.
87
self.check_output_equal(self.metrics.total_ce, 0)
88
self.check_output_equal(self.metrics.total_word_num, 0)
# Feed `steps` random batches to both implementations and compare their
# accumulated perplexity. `steps` is never assigned in the visible lines
# (line 91 is missing), the loop index `i` is unused, and with indentation
# lost it cannot be told whether the final comparison runs on every
# iteration or once after the loop.
90
def test_update_accumulate(self):
92
for i in range(steps):
93
label, logits, pred, _ = self.get_random_case()
94
expected_result = self.np_metrics.compute(pred, label)
95
result = self.metrics.compute(paddle.to_tensor(logits), paddle.to_tensor(label))
96
self.metrics.update(result.numpy())
97
self.np_metrics.update(expected_result)
98
self.check_output_equal(self.metrics.accumulate(), self.np_metrics.accumulate())
# NOTE(review): the body of this guard (line 102) is missing from the
# extraction; presumably it runs the tests, e.g. via unittest.main() -- TODO
# confirm.
101
if __name__ == "__main__":