paddlenlp

Форк
0
/
test_sampler.py 
146 строк · 5.9 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import unittest
17

18
from paddlenlp.data import SamplerHelper
19
from paddlenlp.datasets import load_dataset
20
from tests.common_test import CpuCommonTest
21
from tests.testing_utils import assert_raises, get_tests_dir
22

23

24
def cmp(x, y):
    """Three-way compare ``x`` and ``y`` in the style of Python 2's ``cmp``.

    Returns -1 when ``x < y``, 1 when ``x > y`` and 0 otherwise. ``x < y`` is
    evaluated first and short-circuits the ``x > y`` check.
    """
    if x < y:
        return -1
    if x > y:
        return 1
    return 0
26

27

28
class TestSampler(CpuCommonTest):
    """Tests for ``SamplerHelper`` on the dummy TNEWS fixture (10 samples, see ``test_length``).

    Iteration tests check how many indices were produced as well as the value
    of each one. A sampler that yields too few (or no) indices then fails,
    instead of letting a per-sample ``for`` loop pass without asserting anything.
    """

    @classmethod
    def setUpClass(cls):
        # Load the dummy fixture once; every test builds its own sampler on top of it.
        fixture_path = get_tests_dir(os.path.join("fixtures", "dummy"))
        cls.train_ds = load_dataset("clue", "tnews", data_files=[os.path.join(fixture_path, "tnews", "train.json")])

    def test_length(self):
        train_batch_sampler = SamplerHelper(self.train_ds)
        self.check_output_equal(len(train_batch_sampler), 10)
        self.check_output_equal(len(train_batch_sampler), train_batch_sampler.length)

        # Assigning ``length`` overrides what ``len()`` reports.
        train_batch_sampler.length = 5
        self.check_output_equal(len(train_batch_sampler), 5)

    def test_iter1(self):
        # An explicit index iterator (here: descending order) is yielded as given.
        train_ds_len = len(self.train_ds)
        ds_iter = iter(range(train_ds_len - 1, -1, -1))
        train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
        samples = list(train_batch_sampler)
        self.check_output_equal(len(samples), train_ds_len)
        for i, sample in enumerate(samples):
            self.check_output_equal(i, train_ds_len - 1 - sample)

    def test_iter2(self):
        # Without an explicit iterator the sampler yields 0, 1, ..., len - 1.
        train_batch_sampler = SamplerHelper(self.train_ds)
        samples = list(train_batch_sampler)
        self.check_output_equal(len(samples), len(self.train_ds))
        for i, sample in enumerate(samples):
            self.check_output_equal(i, sample)

    def test_list(self):
        train_batch_sampler = SamplerHelper(self.train_ds)
        list_sampler = train_batch_sampler.list()
        # ``list()`` materializes the indices, so iterating it gives a plain list iterator.
        self.check_output_equal(type(iter(list_sampler)).__name__, "list_iterator")
        samples = list(list_sampler)
        self.check_output_equal(len(samples), len(self.train_ds))
        for i, sample in enumerate(samples):
            self.check_output_equal(i, sample)

    def test_shuffle_no_buffer_size(self):
        # A seeded shuffle over the whole index range is deterministic: the result
        # must be a permutation of 0..len-1 whose first two indices are 4 and 9.
        train_batch_sampler = SamplerHelper(self.train_ds)
        shuffle_sampler = train_batch_sampler.shuffle(seed=102)
        expected_result = {0: 4, 1: 9}
        samples = list(shuffle_sampler)
        self.assertEqual(sorted(samples), list(range(len(self.train_ds))))
        for i, expected in expected_result.items():
            self.check_output_equal(samples[i], expected)

    def test_shuffle_buffer_size(self):
        # Same expectations as the unbuffered shuffle, with an explicit buffer size
        # of 10 (the whole 10-sample fixture).
        train_batch_sampler = SamplerHelper(self.train_ds)
        shuffle_sampler = train_batch_sampler.shuffle(buffer_size=10, seed=102)
        expected_result = {0: 4, 1: 9}
        samples = list(shuffle_sampler)
        self.assertEqual(sorted(samples), list(range(len(self.train_ds))))
        for i, expected in expected_result.items():
            self.check_output_equal(samples[i], expected)

    def test_sort_buffer_size(self):
        # Sorting happens within each buffer of 5 indices, so the descending input
        # (9..0 for the 10-sample fixture) comes out as 5..9 followed by 0..4.
        train_ds_len = len(self.train_ds)
        ds_iter = iter(range(train_ds_len - 1, -1, -1))
        train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
        sort_sampler = train_batch_sampler.sort(cmp=lambda x, y, dataset: cmp(x, y), buffer_size=5)
        samples = list(sort_sampler)
        self.check_output_equal(len(samples), train_ds_len)
        for i, sample in enumerate(samples):
            if i < 5:
                self.check_output_equal(i + 5, sample)
            else:
                self.check_output_equal(i - 5, sample)

    def test_sort_no_buffer_size(self):
        # Without a buffer the whole descending input is sorted into ascending order.
        train_ds_len = len(self.train_ds)
        ds_iter = iter(range(train_ds_len - 1, -1, -1))
        train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
        sort_sampler = train_batch_sampler.sort(cmp=lambda x, y, dataset: cmp(x, y))
        samples = list(sort_sampler)
        self.check_output_equal(len(samples), train_ds_len)
        for i, sample in enumerate(samples):
            self.check_output_equal(i, sample)

    def test_batch(self):
        # Consecutive indices are grouped into batches of ``batch_size``.
        train_batch_sampler = SamplerHelper(self.train_ds)
        batch_size = 3
        batch_sampler = train_batch_sampler.batch(batch_size)
        batches = list(batch_sampler)
        for i, minibatch in enumerate(batches):
            for j, sample in enumerate(minibatch):
                self.check_output_equal(i * batch_size + j, sample)
        # NOTE(review): assumes ``batch`` keeps the final partial batch by default
        # (drop_last=False), so the batches cover every index — confirm against
        # SamplerHelper.batch.
        self.check_output_equal(sum(len(minibatch) for minibatch in batches), len(self.train_ds))

    @assert_raises(ValueError)
    def test_batch_oversize(self):
        # ``batch_size_fn`` reports the full dataset length for every sample, which
        # is larger than ``batch_size``. The decorator expects a ValueError.
        # Presumably it is raised while iterating, so the loop is kept as-is.
        train_batch_sampler = SamplerHelper(self.train_ds)
        batch_size = 3

        batch_sampler = train_batch_sampler.batch(
            batch_size,
            key=lambda size_so_far, minibatch_len: max(size_so_far, minibatch_len),
            batch_size_fn=lambda new, count, sofar, data_source: len(data_source),
        )
        for i, minibatch in enumerate(batch_sampler):
            for j, sample in enumerate(minibatch):
                self.check_output_equal(i * batch_size + j, sample)

    def test_shard(self):
        # Two replicas split the indices round-robin: rank 0 gets the even indices,
        # rank 1 the odd ones, and together they cover every index exactly once.
        train_ds_len = len(self.train_ds)
        train_batch_sampler = SamplerHelper(self.train_ds)
        shard_sampler1 = train_batch_sampler.shard(2, 0)
        shard_sampler2 = train_batch_sampler.shard(2, 1)
        shard_samples1 = list(shard_sampler1)
        shard_samples2 = list(shard_sampler2)

        # The fixture has an even number of samples (10, see test_length), so each
        # shard holds exactly half of them.
        self.check_output_equal(len(shard_samples1), train_ds_len // 2)
        self.check_output_equal(len(shard_samples2), train_ds_len // 2)
        self.assertEqual(sorted(shard_samples1 + shard_samples2), list(range(train_ds_len)))

        for i, sample in enumerate(shard_samples1):
            self.check_output_equal(i * 2, sample)

        for i, sample in enumerate(shard_samples2):
            self.check_output_equal(i * 2 + 1, sample)

    def test_shard_default(self):
        # With default arguments in this CPU test the single shard holds every index in order.
        train_batch_sampler = SamplerHelper(self.train_ds)
        shard_sampler1 = train_batch_sampler.shard()
        samples = list(shard_sampler1)
        self.check_output_equal(len(samples), len(self.train_ds))
        for i, sample in enumerate(samples):
            self.check_output_equal(i, sample)

    def test_apply(self):
        # ``apply`` passes the sampler to a function; sorting the descending input
        # through it must give ascending order.
        train_ds_len = len(self.train_ds)
        ds_iter = iter(range(train_ds_len - 1, -1, -1))
        train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
        apply_sampler = train_batch_sampler.apply(
            lambda sampler: SamplerHelper.sort(sampler, cmp=lambda x, y, dataset: cmp(x, y))
        )
        samples = list(apply_sampler)
        self.check_output_equal(len(samples), train_ds_len)
        for i, sample in enumerate(samples):
            self.check_output_equal(i, sample)
143

144

145
if __name__ == "__main__":
    # Only when executed as a script (not when collected by a test runner):
    # hand the tests defined in this module to unittest's command-line runner.
    unittest.main(module="__main__")
147

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.