# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from paddlenlp.data import SamplerHelper
from paddlenlp.datasets import load_dataset
from tests.common_test import CpuCommonTest
from tests.testing_utils import assert_raises, get_tests_dir
def cmp(x, y):
    """Three-way compare *x* and *y*, like Python 2's builtin ``cmp``.

    Returns:
        -1 if ``x < y``, 1 if ``x > y``, and 0 otherwise.
    """
    return -1 if x < y else 1 if x > y else 0


class TestSampler(CpuCommonTest):
    """Unit tests for ``paddlenlp.data.SamplerHelper``.

    Every test runs against the dummy TNEWS fixture, which ``test_length``
    expects to contain exactly 10 examples.
    """

    @classmethod
    def setUpClass(cls):
        # Load the fixture once for the whole class; the tests only read it.
        fixture_path = get_tests_dir(os.path.join("fixtures", "dummy"))
        cls.train_ds = load_dataset("clue", "tnews", data_files=[os.path.join(fixture_path, "tnews", "train.json")])

    def _reversed_sampler(self):
        """Return a SamplerHelper whose index iterator walks the dataset from last to first."""
        # Build a fresh iterator on every call so no two tests share one.
        return SamplerHelper(self.train_ds, iter(range(len(self.train_ds) - 1, -1, -1)))

    def test_length(self):
        """``len()`` reports the dataset size and follows an assigned ``length``."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        self.check_output_equal(len(train_batch_sampler), 10)
        self.check_output_equal(len(train_batch_sampler), train_batch_sampler.length)

        train_batch_sampler.length = 5
        self.check_output_equal(len(train_batch_sampler), 5)

    def test_iter1(self):
        """An explicit index iterator passed to the constructor is yielded in its own order."""
        train_ds_len = len(self.train_ds)
        train_batch_sampler = self._reversed_sampler()
        for i, sample in enumerate(train_batch_sampler):
            self.check_output_equal(i, train_ds_len - 1 - sample)

    def test_iter2(self):
        """Without an explicit iterator, indices are expected in order 0, 1, ..., n-1."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        for i, sample in enumerate(train_batch_sampler):
            self.check_output_equal(i, sample)

    def test_list(self):
        """``list()`` returns something iterated by a builtin list iterator, order preserved."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        list_sampler = train_batch_sampler.list()
        self.check_output_equal(type(iter(list_sampler)).__name__, "list_iterator")
        for i, sample in enumerate(list_sampler):
            self.check_output_equal(i, sample)

    def test_shuffle_no_buffer_size(self):
        """A seeded shuffle without a buffer is reproducible (spot-checks the first two indices)."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        shuffle_sampler = train_batch_sampler.shuffle(seed=102)
        expected_result = {0: 4, 1: 9}
        for i, sample in enumerate(shuffle_sampler):
            if i in expected_result:
                self.check_output_equal(sample, expected_result[i])

    def test_shuffle_buffer_size(self):
        """A seeded shuffle with ``buffer_size=10`` is reproducible (spot-checks the first two indices)."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        shuffle_sampler = train_batch_sampler.shuffle(buffer_size=10, seed=102)
        expected_result = {0: 4, 1: 9}
        for i, sample in enumerate(shuffle_sampler):
            if i in expected_result:
                self.check_output_equal(sample, expected_result[i])

    def test_sort_buffer_size(self):
        """``sort`` with a buffer orders indices only within each buffer-sized window."""
        buffer_size = 5
        sort_sampler = self._reversed_sampler().sort(cmp=lambda x, y, dataset: cmp(x, y), buffer_size=buffer_size)
        # With the 10-example fixture the input is [9, 8, ..., 0]; sorting each
        # window of 5 independently gives [5, 6, 7, 8, 9, 0, 1, 2, 3, 4].
        for i, sample in enumerate(sort_sampler):
            if i < buffer_size:
                self.check_output_equal(i + buffer_size, sample)
            else:
                self.check_output_equal(i - buffer_size, sample)

    def test_sort_no_buffer_size(self):
        """``sort`` without a buffer fully orders the reversed indices."""
        sort_sampler = self._reversed_sampler().sort(cmp=lambda x, y, dataset: cmp(x, y))
        for i, sample in enumerate(sort_sampler):
            self.check_output_equal(i, sample)

    def test_batch(self):
        """``batch`` groups consecutive indices, ``batch_size`` at a time."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        batch_size = 3
        batch_sampler = train_batch_sampler.batch(batch_size)
        for i, sample in enumerate(batch_sampler):
            for j, minibatch in enumerate(sample):
                self.check_output_equal(i * batch_size + j, minibatch)

    @assert_raises(ValueError)
    def test_batch_oversize(self):
        """Batching must raise ValueError when ``batch_size_fn`` reports an oversized sample."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        batch_size = 3
        batch_sampler = train_batch_sampler.batch(
            batch_size,
            key=lambda size_so_far, minibatch_len: max(size_so_far, minibatch_len),
            # Every sample reports a size equal to the whole dataset (10), which is
            # larger than batch_size -- presumably the condition batch() rejects.
            batch_size_fn=lambda new, count, sofar, data_source: len(data_source),
        )
        # Iterate so the error surfaces even if batching is lazy.
        for i, sample in enumerate(batch_sampler):
            for j, minibatch in enumerate(sample):
                self.check_output_equal(i * batch_size + j, minibatch)

    def test_shard(self):
        """``shard(2, 0)`` keeps even indices and ``shard(2, 1)`` keeps odd indices."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        shard_sampler1 = train_batch_sampler.shard(2, 0)
        shard_sampler2 = train_batch_sampler.shard(2, 1)
        for i, sample in enumerate(shard_sampler1):
            self.check_output_equal(i * 2, sample)

        for i, sample in enumerate(shard_sampler2):
            self.check_output_equal(i * 2 + 1, sample)

    def test_shard_default(self):
        """``shard()`` with default arguments is expected to yield every index in order here."""
        train_batch_sampler = SamplerHelper(self.train_ds)
        shard_sampler1 = train_batch_sampler.shard()
        for i, sample in enumerate(shard_sampler1):
            self.check_output_equal(i, sample)

    def test_apply(self):
        """``apply`` runs a sampler-transforming function (here a full sort) on the sampler."""
        apply_sampler = self._reversed_sampler().apply(
            lambda sampler: SamplerHelper.sort(sampler, cmp=lambda x, y, dataset: cmp(x, y))
        )
        for i, sample in enumerate(apply_sampler):
            self.check_output_equal(i, sample)

if __name__ == "__main__":