1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
19
import numpy as np

from paddlenlp.data import Dict, Pad, Stack, Tuple
from tests import testing_utils
from tests.common_test import CpuCommonTest
24
class TestStack(CpuCommonTest):
    """Check that ``Stack`` turns a batch of equal-length samples into one array."""

    def setUp(self):
        """Build three equal-length samples and the array they should stack into."""
        self.input = [[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]
        # The samples already share a length, so the expected batch is simply
        # the nested list converted to an ndarray.
        self.expected_result = np.array(self.input)

    def test_stack(self):
        """``Stack()`` applied to the batch must equal ``np.array`` of the batch."""
        result = Stack()(self.input)
        self.check_output_equal(self.expected_result, result)
34
class TestPad(CpuCommonTest):
    """Check ``Pad`` with its default arguments on ragged samples."""

    def setUp(self):
        """Build samples of length 4, 3 and 2 and the expected padded batch."""
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
        # Expected: shorter rows are filled with 0 at the end, up to the
        # longest row's length.
        self.expected_result = np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]])

    def test_pad(self):
        """``Pad()`` must produce the zero-filled, end-padded batch."""
        result = Pad()(self.input)
        self.check_output_equal(self.expected_result, result)
44
class TestPadLeft(CpuCommonTest):
    """Check ``Pad(pad_right=False)`` on ragged samples."""

    def setUp(self):
        """Build samples of length 4, 3 and 2 and the expected padded batch."""
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
        # Expected: the zeros are placed in front of the shorter rows.
        self.expected_result = np.array([[1, 2, 3, 4], [0, 4, 5, 6], [0, 0, 8, 9]])

    def test_pad_left(self):
        """``Pad(pad_right=False)`` must produce the front-padded batch."""
        result = Pad(pad_right=False)(self.input)
        self.check_output_equal(self.expected_result, result)
54
class TestPadRetLength(CpuCommonTest):
    """Check ``Pad(ret_length=True)``, which returns a second value besides the batch."""

    def setUp(self):
        """Build samples of length 4, 3 and 2 and the expected padded batch."""
        self.input = [[1, 2, 3, 4], [4, 5, 6], [8, 9]]
        self.expected_result = np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]])

    def test_pad_ret_length(self):
        """The call must return the padded batch and the original sample lengths."""
        result, length = Pad(ret_length=True)(self.input)
        self.check_output_equal(self.expected_result, result)
        # The second return value is compared against the unpadded length of
        # each input row.
        self.check_output_equal(length, np.array([4, 3, 2]))
65
class TestTuple(CpuCommonTest):
    """Check ``Tuple``, which applies one collate function per sample field."""

    def setUp(self):
        """Build two-field samples and the expected per-field batches."""
        # Each sample is [field0, field1]; field0 always has length 4, while
        # field1 is ragged (lengths 4, 3, 2).
        self.input = [[[1, 2, 3, 4], [1, 2, 3, 4]], [[4, 5, 6, 8], [4, 5, 6]], [[8, 9, 1, 2], [8, 9]]]
        self.expected_result = (
            np.array([[1, 2, 3, 4], [4, 5, 6, 8], [8, 9, 1, 2]]),
            np.array([[1, 2, 3, 4], [4, 5, 6, 0], [8, 9, 0, 0]]),
        )

    def _test_impl(self, list_fn=True):
        """Run the shared assertions for one way of constructing ``Tuple``.

        Args:
            list_fn: When true, the collate functions are passed to ``Tuple``
                as a single list; otherwise they are passed as separate
                positional arguments.
        """
        if list_fn:
            batchify_fn = Tuple([Stack(), Pad(axis=0, pad_val=0)])
        else:
            batchify_fn = Tuple(Stack(), Pad(axis=0, pad_val=0))
        result = batchify_fn(self.input)
        self.check_output_equal(result[0], self.expected_result[0])
        self.check_output_equal(result[1], self.expected_result[1])

    def test_tuple(self):
        """Exercise the default ``list_fn=True`` construction path."""
        self._test_impl()

    def test_tuple_list(self):
        """Exercise the ``list_fn=False`` construction path."""
        self._test_impl(False)

    # NOTE(review): ``assert_raises`` is a project helper; presumably it makes
    # the test pass only if the body raises. Confirm against tests/testing_utils.
    @testing_utils.assert_raises
    def test_empty_fn(self):
        """Construct ``Tuple`` with a list plus an extra positional function."""
        Tuple([Stack()], Pad(axis=0, pad_val=0))
93
class TestDict(CpuCommonTest):
    """Check ``Dict``, which applies one collate function per sample key."""

    def setUp(self):
        """Build dict samples with a ragged "text" field and a 1-element "label"."""
        self.input = [
            {"text": [1, 2, 3, 4], "label": [1]},
            {"text": [4, 5, 6], "label": [0]},
            {"text": [7, 8], "label": [1]},
        ]
        # Index 0 is the padded "text" batch, index 1 the stacked "label" batch.
        self.expected_result = (np.array([[1, 2, 3, 4], [4, 5, 6, 0], [7, 8, 0, 0]]), np.array([[1], [0], [1]]))

    def test_dict(self):
        """Pad "text", stack "label", and compare both outputs by position."""
        batchify_fn = Dict({"text": Pad(axis=0, pad_val=0), "label": Stack()})
        result = batchify_fn(self.input)
        self.check_output_equal(result[0], self.expected_result[0])
        self.check_output_equal(result[1], self.expected_result[1])
109
if __name__ == "__main__":