3
from functools import partial
4
from unittest import TestCase
5
from unittest.mock import patch
10
from .utils import require_beam
13
class DummyBeamDataset(datasets.BeamBasedBuilder):
14
"""Dummy beam dataset."""
17
return datasets.DatasetInfo(
18
features=datasets.Features({"content": datasets.Value("string")}),
23
def _split_generators(self, dl_manager, pipeline):
    """Return the single TRAIN split of the dummy dataset.

    ``dl_manager`` and ``pipeline`` are accepted to satisfy the builder
    interface but are not used: the examples are generated in memory by
    ``get_test_dummy_examples()`` and handed on through ``gen_kwargs``.
    """
    train_split = datasets.SplitGenerator(
        name=datasets.Split.TRAIN,
        gen_kwargs={"examples": get_test_dummy_examples()},
    )
    return [train_split]
26
def _build_pcollection(self, pipeline, examples):
    """Turn the in-memory ``examples`` into a Beam collection on ``pipeline``.

    Returns the result of applying a ``beam.Create`` transform, labelled
    "Load Examples", to ``pipeline``.
    """
    # Imported inside the method rather than at module level -- presumably so
    # the module can still be imported when apache_beam is not installed.
    import apache_beam as beam

    return pipeline | "Load Examples" >> beam.Create(examples)
32
class NestedBeamDataset(datasets.BeamBasedBuilder):
33
"""Dummy beam dataset."""
36
return datasets.DatasetInfo(
37
features=datasets.Features({"a": datasets.Sequence({"b": datasets.Value("string")})}),
42
def _split_generators(self, dl_manager, pipeline):
44
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"examples": get_test_nested_examples()})
47
def _build_pcollection(self, pipeline, examples):
    """Turn the in-memory nested ``examples`` into a Beam collection.

    Returns the result of applying a ``beam.Create`` transform, labelled
    "Load Examples", to ``pipeline``.
    """
    # Imported inside the method rather than at module level -- presumably so
    # the module can still be imported when apache_beam is not installed.
    import apache_beam as beam

    return pipeline | "Load Examples" >> beam.Create(examples)
53
def get_test_dummy_examples():
    """Return the flat fixture examples as ``(key, example)`` pairs.

    Each key is the example's position (0, 1, 2) and each example is a dict
    with a single ``"content"`` string. A new list is built on every call.
    """
    return [(i, {"content": content}) for i, content in enumerate(["foo", "bar", "foobar"])]
57
def get_test_nested_examples():
    """Return the nested fixture examples as ``(key, example)`` pairs.

    Each key is the example's position (0, 1, 2) and each example has the
    form ``{"a": {"b": [content]}}``, i.e. a one-element list of strings under
    ``"b"``. A new list is built on every call.
    """
    return [(i, {"a": {"b": [content]}}) for i, content in enumerate(["foo", "bar", "foobar"])]
61
class BeamBuilderTest(TestCase):
63
def test_download_and_prepare(self):
64
expected_num_examples = len(get_test_dummy_examples())
65
with tempfile.TemporaryDirectory() as tmp_cache_dir:
66
builder = DummyBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
67
builder.download_and_prepare()
70
os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train.arrow")
73
self.assertDictEqual(builder.info.features, datasets.Features({"content": datasets.Value("string")}))
74
dset = builder.as_dataset()
75
self.assertEqual(dset["train"].num_rows, expected_num_examples)
76
self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
77
self.assertDictEqual(dset["train"][0], get_test_dummy_examples()[0][1])
79
dset["train"][expected_num_examples - 1], get_test_dummy_examples()[expected_num_examples - 1][1]
82
os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))
87
def test_download_and_prepare_sharded(self):
88
import apache_beam as beam
90
original_write_parquet = beam.io.parquetio.WriteToParquet
92
expected_num_examples = len(get_test_dummy_examples())
93
with tempfile.TemporaryDirectory() as tmp_cache_dir:
94
builder = DummyBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
95
with patch("apache_beam.io.parquetio.WriteToParquet") as write_parquet_mock:
96
write_parquet_mock.side_effect = partial(original_write_parquet, num_shards=2)
97
builder.download_and_prepare()
101
tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00002.arrow"
108
tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00001-of-00002.arrow"
112
self.assertDictEqual(builder.info.features, datasets.Features({"content": datasets.Value("string")}))
113
dset = builder.as_dataset()
114
self.assertEqual(dset["train"].num_rows, expected_num_examples)
115
self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
117
self.assertListEqual(sorted(dset["train"]["content"]), sorted(["foo", "bar", "foobar"]))
119
os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))
124
def test_no_beam_options(self):
    """Check that preparing the dataset without Beam options fails.

    The builder is created with only ``cache_dir`` (no ``beam_runner``), and
    ``download_and_prepare`` is expected to raise
    ``datasets.builder.MissingBeamOptions``.
    """
    with tempfile.TemporaryDirectory() as tmp_cache_dir:
        builder = DummyBeamDataset(cache_dir=tmp_cache_dir)
        with self.assertRaises(datasets.builder.MissingBeamOptions):
            builder.download_and_prepare()
130
def test_nested_features(self):
131
expected_num_examples = len(get_test_nested_examples())
132
with tempfile.TemporaryDirectory() as tmp_cache_dir:
133
builder = NestedBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner")
134
builder.download_and_prepare()
137
os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train.arrow")
140
self.assertDictEqual(
141
builder.info.features, datasets.Features({"a": datasets.Sequence({"b": datasets.Value("string")})})
143
dset = builder.as_dataset()
144
self.assertEqual(dset["train"].num_rows, expected_num_examples)
145
self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples)
146
self.assertDictEqual(dset["train"][0], get_test_nested_examples()[0][1])
147
self.assertDictEqual(
148
dset["train"][expected_num_examples - 1], get_test_nested_examples()[expected_num_examples - 1][1]
151
os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json"))