test_beam.py
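# Tests for Apache Beam-based dataset builders (datasets.BeamBasedBuilder).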
import os
import tempfile
from functools import partial
from unittest import TestCase
from unittest.mock import patch

import datasets
import datasets.config

from .utils import require_beam


class DummyBeamDataset(datasets.BeamBasedBuilder):
    """Dummy beam dataset."""

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({"content": datasets.Value("string")}),
            # No default supervised_keys.
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager, pipeline):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"examples": get_test_dummy_examples()})]

    def _build_pcollection(self, pipeline, examples):
        import apache_beam as beam

        return pipeline | "Load Examples" >> beam.Create(examples)
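

# In a BeamBasedBuilder, `_split_generators` receives the Beam pipeline in addition
# to the download manager, and `_build_pcollection` turns the (key, example) tuples
# into a PCollection that the builder then writes out as Arrow files.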


class NestedBeamDataset(datasets.BeamBasedBuilder):
    """Dummy beam dataset with nested features."""

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({"a": datasets.Sequence({"b": datasets.Value("string")})}),
            # No default supervised_keys.
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager, pipeline):
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"examples": get_test_nested_examples()})
        ]

    def _build_pcollection(self, pipeline, examples):
        import apache_beam as beam

        return pipeline | "Load Examples" >> beam.Create(examples)


def get_test_dummy_examples():
    return [(i, {"content": content}) for i, content in enumerate(["foo", "bar", "foobar"])]


def get_test_nested_examples():
    return [(i, {"a": {"b": [content]}}) for i, content in enumerate(["foo", "bar", "foobar"])]


class BeamBuilderTest(TestCase):
    @require_beam
    def test_download_and_prepare(self):
        expected_num_examples = len(get_test_dummy_examples())
        with tempfile.TemporaryDirectory() as tmp_cache_dir:
            builder = DummyBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
            builder.download_and_prepare()
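            # A single train Arrow file should appear under the cache layout
            # <cache_dir>/<builder_name>/<config_name>/<version>/ ("default", "0.0.0").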
            self.assertTrue(
                os.path.exists(
                    os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train.arrow")
                )
            )
            self.assertDictEqual(builder.info.features, datasets.Features({"content": datasets.Value("string")}))
            dset = builder.as_dataset()
            self.assertEqual(dset["train"].num_rows, expected_num_examples)
            self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
            self.assertDictEqual(dset["train"][0], get_test_dummy_examples()[0][1])
            self.assertDictEqual(
                dset["train"][expected_num_examples - 1], get_test_dummy_examples()[expected_num_examples - 1][1]
            )
            self.assertTrue(
                os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))
            )
            del dset

    @require_beam
    def test_download_and_prepare_sharded(self):
        import apache_beam as beam

        original_write_parquet = beam.io.parquetio.WriteToParquet

        expected_num_examples = len(get_test_dummy_examples())
        with tempfile.TemporaryDirectory() as tmp_cache_dir:
            builder = DummyBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
            with patch("apache_beam.io.parquetio.WriteToParquet") as write_parquet_mock:
                write_parquet_mock.side_effect = partial(original_write_parquet, num_shards=2)
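                # Forcing num_shards=2 through the patched WriteToParquet makes the
                # builder write two shard files, exercising the sharded output path.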
                builder.download_and_prepare()
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00002.arrow"
                    )
                )
            )
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00001-of-00002.arrow"
                    )
                )
            )
            self.assertDictEqual(builder.info.features, datasets.Features({"content": datasets.Value("string")}))
            dset = builder.as_dataset()
            self.assertEqual(dset["train"].num_rows, expected_num_examples)
            self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
            # Order is not preserved when sharding, so we just check that all the elements are there
            self.assertListEqual(sorted(dset["train"]["content"]), sorted(["foo", "bar", "foobar"]))
            self.assertTrue(
                os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))
            )
            del dset

    @require_beam
    def test_no_beam_options(self):
        with tempfile.TemporaryDirectory() as tmp_cache_dir:
            builder = DummyBeamDataset(cache_dir=tmp_cache_dir)
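            # No beam_runner or beam_options were given, so download_and_prepare
            # is expected to raise MissingBeamOptions.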
            self.assertRaises(datasets.builder.MissingBeamOptions, builder.download_and_prepare)

    @require_beam
    def test_nested_features(self):
        expected_num_examples = len(get_test_nested_examples())
        with tempfile.TemporaryDirectory() as tmp_cache_dir:
            builder = NestedBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
            builder.download_and_prepare()
            self.assertTrue(
                os.path.exists(
                    os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train.arrow")
                )
            )
            self.assertDictEqual(
                builder.info.features, datasets.Features({"a": datasets.Sequence({"b": datasets.Value("string")})})
            )
            dset = builder.as_dataset()
            self.assertEqual(dset["train"].num_rows, expected_num_examples)
            self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
            self.assertDictEqual(dset["train"][0], get_test_nested_examples()[0][1])
            self.assertDictEqual(
                dset["train"][expected_num_examples - 1], get_test_nested_examples()[expected_num_examples - 1][1]
            )
            self.assertTrue(
                os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))
            )
            del dset
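

# Note: every test above is guarded by @require_beam (imported from .utils), which
# presumably skips the test when apache-beam is not installed in the environment.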