paddlenlp

Форк
0
/
test_collate.py 
110 строк · 3.7 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import unittest
16

17
import numpy as np
18

19
from paddlenlp.data import Dict, Pad, Stack, Tuple
20
from tests import testing_utils
21
from tests.common_test import CpuCommonTest
22

23

24
class TestStack(CpuCommonTest):
25
    def setUp(self):
26
        self.input = [[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]
27
        self.expected_result = np.array(self.input)
28

29
    def test_stack(self):
30
        result = Stack()(self.input)
31
        self.check_output_equal(self.expected_result, result)
32

33

34
class TestPad(CpuCommonTest):
35
    def setUp(self):
36
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
37
        self.expected_result = np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]])
38

39
    def test_pad(self):
40
        result = Pad()(self.input)
41
        self.check_output_equal(self.expected_result, result)
42

43

44
class TestPadLeft(CpuCommonTest):
45
    def setUp(self):
46
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
47
        self.expected_result = np.array([[1, 2, 3, 4], [0, 4, 5, 6], [0, 0, 8, 9]])
48

49
    def test_pad(self):
50
        result = Pad(pad_right=False)(self.input)
51
        self.check_output_equal(self.expected_result, result)
52

53

54
class TestPadRetLength(CpuCommonTest):
55
    def setUp(self):
56
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
57
        self.expected_result = np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]])
58

59
    def test_pad(self):
60
        result, length = Pad(ret_length=True)(self.input)
61
        self.check_output_equal(self.expected_result, result)
62
        self.check_output_equal(length, np.array([4, 3, 2]))
63

64

65
class TestTuple(CpuCommonTest):
66
    def setUp(self):
67
        self.input = [[[1, 2, 3, 4], [1, 2, 3, 4]], [[4, 5, 6, 8], [4, 5, 6]], [[8, 9, 1, 2], [8, 9]]]
68
        self.expected_result = (
69
            np.array([[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]),
70
            np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]]),
71
        )
72

73
    def _test_impl(self, list_fn=True):
74
        if list_fn:
75
            batchify_fn = Tuple([Stack(), Pad(axis=0, pad_val=0)])
76
        else:
77
            batchify_fn = Tuple(Stack(), Pad(axis=0, pad_val=0))
78
        result = batchify_fn(self.input)
79
        self.check_output_equal(result[0], self.expected_result[0])
80
        self.check_output_equal(result[1], self.expected_result[1])
81

82
    def test_tuple(self):
83
        self._test_impl()
84

85
    def test_tuple_list(self):
86
        self._test_impl(False)
87

88
    @testing_utils.assert_raises
89
    def test_empty_fn(self):
90
        Tuple([Stack()], Pad(axis=0, pad_val=0))
91

92

93
class TestDict(CpuCommonTest):
94
    def setUp(self):
95
        self.input = [
96
            {"text": [1, 2, 3, 4], "label": [1]},
97
            {"text": [4, 5, 6], "label": [0]},
98
            {"text": [7, 8], "label": [1]},
99
        ]
100
        self.expected_result = (np.array([[1, 2, 3, 4], [4, 5, 6, 0], [7, 8, 0, 0]]), np.array([[1], [0], [1]]))
101

102
    def test_dict(self):
103
        batchify_fn = Dict({"text": Pad(axis=0, pad_val=0), "label": Stack()})
104
        result = batchify_fn(self.input)
105
        self.check_output_equal(result[0], self.expected_result[0])
106
        self.check_output_equal(result[1], self.expected_result[1])
107

108

109
if __name__ == "__main__":
110
    unittest.main()
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.