import importlib
import os
import tempfile
import types
from contextlib import nullcontext as does_not_raise
from multiprocessing import Process
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from multiprocess.pool import Pool

from datasets.arrow_dataset import Dataset
from datasets.arrow_reader import DatasetNotOnHfGcsError
from datasets.arrow_writer import ArrowWriter
from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.download.download_manager import DownloadMode
from datasets.features import Features, Value
from datasets.info import DatasetInfo, PostProcessedInfo
from datasets.iterable_dataset import IterableDataset
from datasets.load import configure_builder_class
from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
from datasets.streaming import xjoin
from datasets.utils.file_utils import is_local_path
from datasets.utils.info_utils import VerificationMode
from datasets.utils.logging import INFO, get_logger

from .utils import (
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
    require_beam,
    require_faiss,
    set_current_working_directory_to_temp_dir,
)

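# Minimal builder fixtures: each subclass below implements just enough of the
# DatasetBuilder API (_info, _split_generators and a data-producing method)
# to exercise download_and_prepare, caching, post-processing and streaming.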
class DummyBuilder(DatasetBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _prepare_split(self, split_generator, **kwargs):
        fname = f"{self.dataset_name}-{split_generator.name}.arrow"
        with ArrowWriter(features=self.info.features, path=os.path.join(self._output_dir, fname)) as writer:
            writer.write_batch({"text": ["foo"] * 100})
            num_examples, num_bytes = writer.finalize()
        split_generator.split_info.num_examples = num_examples
        split_generator.split_info.num_bytes = num_bytes

class DummyGeneratorBasedBuilder(GeneratorBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _generate_examples(self):
        for i in range(100):
            yield i, {"text": "foo"}

class DummyArrowBasedBuilder(ArrowBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _generate_tables(self):
        for i in range(10):
            yield i, pa.table({"text": ["foo"] * 10})

class DummyBeamBasedBuilder(BeamBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _build_pcollection(self, pipeline):
        import apache_beam as beam

        def _process(item):
            for i in range(10):
                yield f"{i}_{item}", {"text": "foo"}

        return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process)

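# The beam-based builder imports apache_beam lazily inside _build_pcollection;
# the tests that actually run it are gated with @require_beam and use the
# local "DirectRunner".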
class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"id": Value("int8")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _generate_examples(self):
        for i in range(100):
            yield i, {"id": i}

class DummyGeneratorBasedBuilderConfig(BuilderConfig):
    def __init__(self, content="foo", times=2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.content = content
        self.times = times

class DummyGeneratorBasedBuilderWithConfig(GeneratorBasedBuilder):
    BUILDER_CONFIG_CLASS = DummyGeneratorBasedBuilderConfig

    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN)]

    def _generate_examples(self):
        for i in range(100):
            yield i, {"text": self.config.content * self.config.times}

class DummyBuilderWithMultipleConfigs(DummyBuilder):
    BUILDER_CONFIGS = [
        DummyGeneratorBasedBuilderConfig(name="a"),
        DummyGeneratorBasedBuilderConfig(name="b"),
    ]

class DummyBuilderWithDefaultConfig(DummyBuilderWithMultipleConfigs):
    DEFAULT_CONFIG_NAME = "a"

class DummyBuilderWithDownload(DummyBuilder):
    def __init__(self, *args, rel_path=None, abs_path=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._rel_path = rel_path
        self._abs_path = abs_path

    def _split_generators(self, dl_manager):
        if self._rel_path is not None:
            assert os.path.exists(dl_manager.download(self._rel_path)), "dl_manager must support relative paths"
        if self._abs_path is not None:
            assert os.path.exists(dl_manager.download(self._abs_path)), "dl_manager must support absolute paths"
        return [SplitGenerator(name=Split.TRAIN)]

class DummyBuilderWithManualDownload(DummyBuilderWithMultipleConfigs):
    @property
    def manual_download_instructions(self):
        return "To use the dataset you have to download some stuff manually and pass the data path to data_dir"

    def _split_generators(self, dl_manager):
        if not os.path.exists(self.config.data_dir):
            raise FileNotFoundError(f"data_dir {self.config.data_dir} doesn't exist.")
        return [SplitGenerator(name=Split.TRAIN)]

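# Sharded builders: the gen_kwargs lists act as shards that can be distributed
# across worker processes via num_proc. The "ambiguous" variants pass two lists
# of different lengths, which cannot be split consistently across workers, so
# preparing them with num_proc > 1 is expected to raise.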
class DummyArrowBasedBuilderWithShards(ArrowBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})]

    def _generate_tables(self, filepaths):
        idx = 0
        for filepath in filepaths:
            for i in range(10):
                yield idx, pa.table({"id": range(10 * i, 10 * (i + 1)), "filepath": [filepath] * 10})
                idx += 1

class DummyGeneratorBasedBuilderWithShards(GeneratorBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))

    def _split_generators(self, dl_manager):
        return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})]

    def _generate_examples(self, filepaths):
        idx = 0
        for filepath in filepaths:
            for i in range(100):
                yield idx, {"id": i, "filepath": filepath}
                idx += 1

class DummyArrowBasedBuilderWithAmbiguousShards(ArrowBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))

    def _split_generators(self, dl_manager):
        return [
            SplitGenerator(
                name=Split.TRAIN,
                gen_kwargs={
                    "filepaths": [f"data{i}.txt" for i in range(4)],
                    "dummy_kwarg_with_different_length": [f"dummy_data{i}.txt" for i in range(3)],
                },
            )
        ]

    def _generate_tables(self, filepaths, dummy_kwarg_with_different_length):
        idx = 0
        for filepath in filepaths:
            for i in range(10):
                yield idx, pa.table({"id": range(10 * i, 10 * (i + 1)), "filepath": [filepath] * 10})
                idx += 1

class DummyGeneratorBasedBuilderWithAmbiguousShards(GeneratorBasedBuilder):
    def _info(self):
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))

    def _split_generators(self, dl_manager):
        return [
            SplitGenerator(
                name=Split.TRAIN,
                gen_kwargs={
                    "filepaths": [f"data{i}.txt" for i in range(4)],
                    "dummy_kwarg_with_different_length": [f"dummy_data{i}.txt" for i in range(3)],
                },
            )
        ]

    def _generate_examples(self, filepaths, dummy_kwarg_with_different_length):
        idx = 0
        for filepath in filepaths:
            for i in range(100):
                yield idx, {"id": i, "filepath": filepath}
                idx += 1

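# Helpers used by the multiprocessing tests below; they live at module level
# so they can be pickled and sent to worker processes.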
def _run_concurrent_download_and_prepare(tmp_dir):
    builder = DummyBuilder(cache_dir=tmp_dir)
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
    return builder

def check_streaming(builder):
    builders_module = importlib.import_module(builder.__module__)
    assert builders_module._patched_for_streaming
    assert builders_module.os.path.join is xjoin

class BuilderTest(TestCase):
    def test_download_and_prepare(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
                    )
                )
            )
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
            self.assertTrue(
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )

    def test_download_and_prepare_checksum_computation(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            builder_no_verification = DummyBuilder(cache_dir=tmp_dir)
            builder_no_verification.download_and_prepare(
                try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD
            )
            self.assertTrue(
                all(v["checksum"] is not None for _, v in builder_no_verification.info.download_checksums.items())
            )
            builder_with_verification = DummyBuilder(cache_dir=tmp_dir)
            builder_with_verification.download_and_prepare(
                try_from_hf_gcs=False,
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
                verification_mode=VerificationMode.ALL_CHECKS,
            )
            self.assertTrue(
                all(v["checksum"] is None for _, v in builder_with_verification.info.download_checksums.items())
            )

    def test_concurrent_download_and_prepare(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            processes = 2
            with Pool(processes=processes) as pool:
                jobs = [
                    pool.apply_async(_run_concurrent_download_and_prepare, kwds={"tmp_dir": tmp_dir})
                    for _ in range(processes)
                ]
                builders = [job.get() for job in jobs]
                for builder in builders:
                    self.assertTrue(
                        os.path.exists(
                            os.path.join(
                                tmp_dir,
                                builder.dataset_name,
                                "default",
                                "0.0.0",
                                f"{builder.dataset_name}-train.arrow",
                            )
                        )
                    )
                    self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
                    self.assertEqual(builder.info.splits["train"].num_examples, 100)
                    self.assertTrue(
                        os.path.exists(
                            os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json")
                        )
                    )

    def test_download_and_prepare_with_base_path(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            rel_path = "dummy1.data"
            abs_path = os.path.join(tmp_dir, "dummy2.data")

            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, rel_path=rel_path)
            with self.assertRaises(FileNotFoundError):
                builder.download_and_prepare(
                    try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
                )

            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, abs_path=abs_path)
            with self.assertRaises(FileNotFoundError):
                builder.download_and_prepare(
                    try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
                )

            open(os.path.join(tmp_dir, rel_path), "w")
            open(abs_path, "w")
            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, rel_path=rel_path, abs_path=abs_path)
            builder.download_and_prepare(
                try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
            )
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir,
                        builder.dataset_name,
                        "default",
                        "0.0.0",
                        f"{builder.dataset_name}-train.arrow",
                    )
                )
            )

    def test_as_dataset_with_post_process(self):
        def _post_process(self, dataset, resources_paths):
            def char_tokenize(example):
                return {"tokens": list(example["text"])}

            return dataset.map(char_tokenize, cache_file_name=resources_paths["tokenized_dataset"])

        def _post_processing_resources(self, split):
            return {"tokenized_dataset": f"tokenized_dataset-{split}.arrow"}

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder.info.post_processed = PostProcessedInfo(
                features=Features({"text": Value("string"), "tokens": [Value("string")]})
            )
            builder._post_process = types.MethodType(_post_process, builder)
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
            os.makedirs(builder.cache_dir)

            builder.info.splits = SplitDict()
            builder.info.splits.add(SplitInfo("train", num_examples=10))
            builder.info.splits.add(SplitInfo("test", num_examples=10))

            for split in builder.info.splits:
                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
                    features=Features({"text": Value("string")}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 10})
                    writer.finalize()

                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"tokenized_dataset-{split}.arrow"),
                    features=Features({"text": Value("string"), "tokens": [Value("string")]}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 10, "tokens": [list("foo")] * 10})
                    writer.finalize()

            dsets = builder.as_dataset()
            self.assertIsInstance(dsets, DatasetDict)
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
            self.assertEqual(len(dsets["train"]), 10)
            self.assertEqual(len(dsets["test"]), 10)
            self.assertDictEqual(
                dsets["train"].features, Features({"text": Value("string"), "tokens": [Value("string")]})
            )
            self.assertDictEqual(
                dsets["test"].features, Features({"text": Value("string"), "tokens": [Value("string")]})
            )
            self.assertListEqual(dsets["train"].column_names, ["text", "tokens"])
            self.assertListEqual(dsets["test"].column_names, ["text", "tokens"])
            del dsets

            dset = builder.as_dataset("train")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train")
            self.assertEqual(len(dset), 10)
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
            self.assertListEqual(dset.column_names, ["text", "tokens"])
            self.assertGreater(builder.info.post_processing_size, 0)
            self.assertGreater(
                builder.info.post_processed.resources_checksums["train"]["tokenized_dataset"]["num_bytes"], 0
            )
            del dset

            dset = builder.as_dataset("train+test[:30%]")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train+test[:30%]")
            self.assertEqual(len(dset), 13)
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
            self.assertListEqual(dset.column_names, ["text", "tokens"])
            del dset

            dset = builder.as_dataset("all")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train+test")
            self.assertEqual(len(dset), 20)
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
            self.assertListEqual(dset.column_names, ["text", "tokens"])
            del dset
        def _post_process(self, dataset, resources_paths):
            return dataset.select([0, 1], keep_in_memory=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder._post_process = types.MethodType(_post_process, builder)
            os.makedirs(builder.cache_dir)

            builder.info.splits = SplitDict()
            builder.info.splits.add(SplitInfo("train", num_examples=10))
            builder.info.splits.add(SplitInfo("test", num_examples=10))

            for split in builder.info.splits:
                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
                    features=Features({"text": Value("string")}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 10})
                    writer.finalize()

                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"small_dataset-{split}.arrow"),
                    features=Features({"text": Value("string")}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 2})
                    writer.finalize()

            dsets = builder.as_dataset()
            self.assertIsInstance(dsets, DatasetDict)
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
            self.assertEqual(len(dsets["train"]), 2)
            self.assertEqual(len(dsets["test"]), 2)
            self.assertDictEqual(dsets["train"].features, Features({"text": Value("string")}))
            self.assertDictEqual(dsets["test"].features, Features({"text": Value("string")}))
            self.assertListEqual(dsets["train"].column_names, ["text"])
            self.assertListEqual(dsets["test"].column_names, ["text"])
            del dsets

            dset = builder.as_dataset("train")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train")
            self.assertEqual(len(dset), 2)
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
            self.assertListEqual(dset.column_names, ["text"])
            del dset

            dset = builder.as_dataset("train+test[:30%]")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train+test[:30%]")
            self.assertEqual(len(dset), 2)
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
            self.assertListEqual(dset.column_names, ["text"])
            del dset

    @require_faiss
    def test_as_dataset_with_post_process_with_index(self):
        def _post_process(self, dataset, resources_paths):
            if os.path.exists(resources_paths["index"]):
                dataset.load_faiss_index("my_index", resources_paths["index"])
                return dataset
            else:
                dataset.add_faiss_index_from_external_arrays(
                    external_arrays=np.ones((len(dataset), 8)), string_factory="Flat", index_name="my_index"
                )
                dataset.save_faiss_index("my_index", resources_paths["index"])
                return dataset

        def _post_processing_resources(self, split):
            return {"index": f"Flat-{split}.faiss"}

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder._post_process = types.MethodType(_post_process, builder)
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
            os.makedirs(builder.cache_dir)

            builder.info.splits = SplitDict()
            builder.info.splits.add(SplitInfo("train", num_examples=10))
            builder.info.splits.add(SplitInfo("test", num_examples=10))

            for split in builder.info.splits:
                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
                    features=Features({"text": Value("string")}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 10})
                    writer.finalize()

                with ArrowWriter(
                    path=os.path.join(builder.cache_dir, f"small_dataset-{split}.arrow"),
                    features=Features({"text": Value("string")}),
                ) as writer:
                    writer.write_batch({"text": ["foo"] * 2})
                    writer.finalize()

            dsets = builder.as_dataset()
            self.assertIsInstance(dsets, DatasetDict)
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
            self.assertEqual(len(dsets["train"]), 10)
            self.assertEqual(len(dsets["test"]), 10)
            self.assertDictEqual(dsets["train"].features, Features({"text": Value("string")}))
            self.assertDictEqual(dsets["test"].features, Features({"text": Value("string")}))
            self.assertListEqual(dsets["train"].column_names, ["text"])
            self.assertListEqual(dsets["test"].column_names, ["text"])
            self.assertListEqual(dsets["train"].list_indexes(), ["my_index"])
            self.assertListEqual(dsets["test"].list_indexes(), ["my_index"])
            self.assertGreater(builder.info.post_processing_size, 0)
            self.assertGreater(builder.info.post_processed.resources_checksums["train"]["index"]["num_bytes"], 0)
            del dsets

            dset = builder.as_dataset("train")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train")
            self.assertEqual(len(dset), 10)
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
            self.assertListEqual(dset.column_names, ["text"])
            self.assertListEqual(dset.list_indexes(), ["my_index"])
            del dset

            dset = builder.as_dataset("train+test[:30%]")
            self.assertIsInstance(dset, Dataset)
            self.assertEqual(dset.split, "train+test[:30%]")
            self.assertEqual(len(dset), 13)
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
            self.assertListEqual(dset.column_names, ["text"])
            self.assertListEqual(dset.list_indexes(), ["my_index"])
            del dset

    @require_faiss
    def test_download_and_prepare_with_post_process(self):
        def _post_process(self, dataset, resources_paths):
            def char_tokenize(example):
                return {"tokens": list(example["text"])}

            return dataset.map(char_tokenize, cache_file_name=resources_paths["tokenized_dataset"])

        def _post_processing_resources(self, split):
            return {"tokenized_dataset": f"tokenized_dataset-{split}.arrow"}

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder.info.post_processed = PostProcessedInfo(
                features=Features({"text": Value("string"), "tokens": [Value("string")]})
            )
            builder._post_process = types.MethodType(_post_process, builder)
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
                    )
                )
            )
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
            self.assertDictEqual(
                builder.info.post_processed.features,
                Features({"text": Value("string"), "tokens": [Value("string")]}),
            )
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
            self.assertTrue(
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )
        def _post_process(self, dataset, resources_paths):
            return dataset.select([0, 1], keep_in_memory=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder._post_process = types.MethodType(_post_process, builder)
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
                    )
                )
            )
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
            self.assertIsNone(builder.info.post_processed)
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
            self.assertTrue(
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )
        def _post_process(self, dataset, resources_paths):
            if os.path.exists(resources_paths["index"]):
                dataset.load_faiss_index("my_index", resources_paths["index"])
                return dataset
            else:
                dataset = dataset.add_faiss_index_from_external_arrays(
                    external_arrays=np.ones((len(dataset), 8)), string_factory="Flat", index_name="my_index"
                )
                dataset.save_faiss_index("my_index", resources_paths["index"])
                return dataset

        def _post_processing_resources(self, split):
            return {"index": f"Flat-{split}.faiss"}

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder._post_process = types.MethodType(_post_process, builder)
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
                    )
                )
            )
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
            self.assertIsNone(builder.info.post_processed)
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
            self.assertTrue(
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )

    def test_error_download_and_prepare(self):
        def _prepare_split(self, split_generator, **kwargs):
            raise ValueError()

        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyBuilder(cache_dir=tmp_dir)
            builder._prepare_split = types.MethodType(_prepare_split, builder)
            self.assertRaises(
                ValueError,
                builder.download_and_prepare,
                try_from_hf_gcs=False,
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
            )
            self.assertRaises(FileNotFoundError, builder.as_dataset)

    def test_generator_based_download_and_prepare(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir)
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
            self.assertTrue(
                os.path.exists(
                    os.path.join(
                        tmp_dir,
                        builder.dataset_name,
                        "default",
                        "0.0.0",
                        f"{builder.dataset_name}-train.arrow",
                    )
                )
            )
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
            self.assertTrue(
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )

    def test_check_duplicates_with_verification_mode(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir)
            with patch("datasets.builder.ArrowWriter", side_effect=ArrowWriter) as mock_arrow_writer:
                builder.download_and_prepare(
                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.NO_CHECKS
                )
                mock_arrow_writer.assert_called_once()
                args, kwargs = mock_arrow_writer.call_args_list[0]
                self.assertFalse(kwargs["check_duplicates"])

                mock_arrow_writer.reset_mock()

                builder.download_and_prepare(
                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.BASIC_CHECKS
                )
                mock_arrow_writer.assert_called_once()
                args, kwargs = mock_arrow_writer.call_args_list[0]
                self.assertTrue(kwargs["check_duplicates"])

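    # The cache directory name is derived from the config name, the data files,
    # the features, data_dir and any custom config kwargs; the tests below check
    # which of these inputs produce a distinct cache directory.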
    def test_cache_dir_no_args(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_dir=None, data_files=None)
            relative_cache_dir_parts = Path(builder._relative_data_dir()).parts
            self.assertTupleEqual(relative_cache_dir_parts, (builder.dataset_name, "default", "0.0.0"))

    def test_cache_dir_for_data_files(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_data1 = os.path.join(tmp_dir, "dummy_data1.txt")
            with open(dummy_data1, "w", encoding="utf-8") as f:
                f.writelines("foo bar")
            dummy_data2 = os.path.join(tmp_dir, "dummy_data2.txt")
            with open(dummy_data2, "w", encoding="utf-8") as f:
                f.writelines("foo bar\n")

            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data1)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data1)
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1])
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"train": dummy_data1})
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={Split.TRAIN: dummy_data1})
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"train": [dummy_data1]})
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"test": dummy_data1})
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data2)
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data2])
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data2, dummy_data1])
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

            builder = DummyGeneratorBasedBuilder(
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "test": dummy_data2}
            )
            other_builder = DummyGeneratorBasedBuilder(
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "test": dummy_data2}
            )
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(
                cache_dir=tmp_dir, data_files={"train": [dummy_data1], "test": dummy_data2}
            )
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "validation": dummy_data2}
            )
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilder(
                cache_dir=tmp_dir,
                data_files={"train": [dummy_data1, dummy_data2], "test": dummy_data2},
            )
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

    def test_cache_dir_for_features(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            f1 = Features({"id": Value("int8")})
            f2 = Features({"id": Value("int32")})
            builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f1)
            other_builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f1)
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f2)
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

    def test_cache_dir_for_config_kwargs(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create config on the fly
            builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="foo", times=2)
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, times=2, content="foo")
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            self.assertIn("content=foo", builder.cache_dir)
            self.assertIn("times=2", builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="bar", times=2)
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="foo")
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # overwrite an existing config
            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="foo", times=2)
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", times=2, content="foo")
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            self.assertIn("content=foo", builder.cache_dir)
            self.assertIn("times=2", builder.cache_dir)
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="bar", times=2)
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="foo")
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

    def test_config_names(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(ValueError) as error_context:
                DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, data_files=None, data_dir=None)
            self.assertIn("Please pick one among the available configs", str(error_context.exception))

            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a")
            self.assertEqual(builder.config.name, "a")

            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="b")
            self.assertEqual(builder.config.name, "b")

            with self.assertRaises(ValueError):
                DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir)

            builder = DummyBuilderWithDefaultConfig(cache_dir=tmp_dir)
            self.assertEqual(builder.config.name, "a")

    def test_cache_dir_for_data_dir(self):
        with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
            builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir)
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

    def test_cache_dir_for_configured_builder(self):
        with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
            builder_cls = configure_builder_class(
                DummyBuilderWithManualDownload,
                builder_configs=[BuilderConfig(data_dir=data_dir)],
                default_config_name=None,
                dataset_name="dummy",
            )
            builder = builder_cls(cache_dir=tmp_dir, hash="abc")
            other_builder = builder_cls(cache_dir=tmp_dir, hash="abc")
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
            other_builder = builder_cls(cache_dir=tmp_dir, hash="def")
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

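# The remaining tests are module-level pytest functions that rely on the
# tmp_path (and mockfs) fixtures instead of unittest.TestCase.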
def test_arrow_based_download_and_prepare(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare()
    assert os.path.exists(
        os.path.join(
            tmp_path,
            builder.dataset_name,
            "default",
            "0.0.0",
            f"{builder.dataset_name}-train.arrow",
        )
    )
    assert builder.info.features == Features({"text": Value("string")})
    assert builder.info.splits["train"].num_examples == 100
    assert os.path.exists(os.path.join(tmp_path, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))

@require_beam
def test_beam_based_download_and_prepare(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
    builder.download_and_prepare()
    assert os.path.exists(
        os.path.join(
            tmp_path,
            builder.dataset_name,
            "default",
            "0.0.0",
            f"{builder.dataset_name}-train.arrow",
        )
    )
    assert builder.info.features == Features({"text": Value("string")})
    assert builder.info.splits["train"].num_examples == 100
    assert os.path.exists(os.path.join(tmp_path, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))

@require_beam
def test_beam_based_as_dataset(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
    builder.download_and_prepare()
    dataset = builder.as_dataset()
    assert dataset
    assert isinstance(dataset["train"], Dataset)
    assert len(dataset["train"]) > 0

@pytest.mark.parametrize(
    "split, expected_dataset_class, expected_dataset_length",
    [
        (None, DatasetDict, 10),
        ("train", Dataset, 10),
        ("train+test[:30%]", Dataset, 13),
    ],
)
@pytest.mark.parametrize("in_memory", [False, True])
def test_builder_as_dataset(split, expected_dataset_class, expected_dataset_length, in_memory, tmp_path):
    cache_dir = str(tmp_path)
    builder = DummyBuilder(cache_dir=cache_dir)
    os.makedirs(builder.cache_dir)

    builder.info.splits = SplitDict()
    builder.info.splits.add(SplitInfo("train", num_examples=10))
    builder.info.splits.add(SplitInfo("test", num_examples=10))

    for info_split in builder.info.splits:
        with ArrowWriter(
            path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{info_split}.arrow"),
            features=Features({"text": Value("string")}),
        ) as writer:
            writer.write_batch({"text": ["foo"] * 10})
            writer.finalize()

    with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
        dataset = builder.as_dataset(split=split, in_memory=in_memory)
    assert isinstance(dataset, expected_dataset_class)
    if isinstance(dataset, DatasetDict):
        assert list(dataset.keys()) == ["train", "test"]
        datasets = dataset.values()
        expected_splits = ["train", "test"]
    elif isinstance(dataset, Dataset):
        datasets = [dataset]
        expected_splits = [split]
    for dataset, expected_split in zip(datasets, expected_splits):
        assert dataset.split == expected_split
        assert len(dataset) == expected_dataset_length
        assert dataset.features == Features({"text": Value("string")})
        assert dataset.column_names == ["text"]

@pytest.mark.parametrize("in_memory", [False, True])
def test_generator_based_builder_as_dataset(in_memory, tmp_path):
    cache_dir = tmp_path / "data"
    cache_dir.mkdir()
    cache_dir = str(cache_dir)
    builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir)
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
        dataset = builder.as_dataset("train", in_memory=in_memory)
    assert dataset.data.to_pydict() == {"text": ["foo"] * 100}

@pytest.mark.parametrize(
    "writer_batch_size, default_writer_batch_size, expected_chunks", [(None, None, 1), (None, 5, 20), (10, None, 10)]
)
def test_custom_writer_batch_size(tmp_path, writer_batch_size, default_writer_batch_size, expected_chunks):
    cache_dir = str(tmp_path)
    if default_writer_batch_size:
        DummyGeneratorBasedBuilder.DEFAULT_WRITER_BATCH_SIZE = default_writer_batch_size
    builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir, writer_batch_size=writer_batch_size)
    assert builder._writer_batch_size == (writer_batch_size or default_writer_batch_size)
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    dataset = builder.as_dataset("train")
    assert len(dataset.data[0].chunks) == expected_chunks

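# Streaming mode patches the builder's module (e.g. os.path.join -> xjoin) so
# that data files can be opened lazily; check_streaming() asserts that patch.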
def test_builder_as_streaming_dataset(tmp_path):
    dummy_builder = DummyGeneratorBasedBuilder(cache_dir=str(tmp_path))
    check_streaming(dummy_builder)
    dsets = dummy_builder.as_streaming_dataset()
    assert isinstance(dsets, IterableDatasetDict)
    assert isinstance(dsets["train"], IterableDataset)
    assert len(list(dsets["train"])) == 100
    dset = dummy_builder.as_streaming_dataset(split="train")
    assert isinstance(dset, IterableDataset)
    assert len(list(dset)) == 100

@require_beam
def test_beam_based_builder_as_streaming_dataset(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path)
    check_streaming(builder)
    with pytest.raises(DatasetNotOnHfGcsError):
        builder.as_streaming_dataset()

def _run_test_builder_streaming_works_in_subprocesses(builder):
    check_streaming(builder)
    dset = builder.as_streaming_dataset(split="train")
    assert isinstance(dset, IterableDataset)
    assert len(list(dset)) == 100


def test_builder_streaming_works_in_subprocess(tmp_path):
    dummy_builder = DummyGeneratorBasedBuilder(cache_dir=str(tmp_path))
    p = Process(target=_run_test_builder_streaming_works_in_subprocesses, args=(dummy_builder,))
    p.start()
    p.join()

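# Builders used to check how the config version is resolved: from the class
# VERSION attribute, from BUILDER_CONFIGS, or from a custom BuilderConfig class.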
class DummyBuilderWithVersion(GeneratorBasedBuilder):
    VERSION = "2.0.0"

    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        pass

    def _generate_examples(self):
        pass

class DummyBuilderWithBuilderConfigs(GeneratorBasedBuilder):
    BUILDER_CONFIGS = [BuilderConfig(name="custom", version="2.0.0")]

    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        pass

    def _generate_examples(self):
        pass

class CustomBuilderConfig(BuilderConfig):
    def __init__(self, date=None, language=None, version="2.0.0", **kwargs):
        name = f"{date}.{language}"
        super().__init__(name=name, version=version, **kwargs)
        self.date = date
        self.language = language

class DummyBuilderWithCustomBuilderConfigs(GeneratorBasedBuilder):
    BUILDER_CONFIGS = [CustomBuilderConfig(date="20220501", language="en")]
    BUILDER_CONFIG_CLASS = CustomBuilderConfig

    def _info(self):
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        pass

    def _generate_examples(self):
        pass

@pytest.mark.parametrize(
    "builder_class, kwargs",
    [
        (DummyBuilderWithVersion, {}),
        (DummyBuilderWithBuilderConfigs, {"config_name": "custom"}),
        (DummyBuilderWithCustomBuilderConfigs, {"config_name": "20220501.en"}),
        (DummyBuilderWithCustomBuilderConfigs, {"date": "20220501", "language": "ca"}),
    ],
)
def test_builder_config_version(builder_class, kwargs, tmp_path):
    cache_dir = str(tmp_path)
    builder = builder_class(cache_dir=cache_dir, **kwargs)
    assert builder.config.version == "2.0.0"

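# download_and_prepare can also target an explicit output_dir or a remote
# filesystem (via storage_options) instead of the local cache directory.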
def test_builder_download_and_prepare_with_absolute_output_dir(tmp_path):
    builder = DummyGeneratorBasedBuilder()
    output_dir = str(tmp_path)
    builder.download_and_prepare(output_dir)
    assert builder._output_dir.startswith(tmp_path.resolve().as_posix())
    assert os.path.exists(os.path.join(output_dir, "dataset_info.json"))
    assert os.path.exists(os.path.join(output_dir, f"{builder.dataset_name}-train.arrow"))
    assert not os.path.exists(os.path.join(output_dir + ".incomplete"))

def test_builder_download_and_prepare_with_relative_output_dir():
    with set_current_working_directory_to_temp_dir():
        builder = DummyGeneratorBasedBuilder()
        output_dir = "test-out"
        builder.download_and_prepare(output_dir)
        assert Path(builder._output_dir).resolve().as_posix().startswith(Path(output_dir).resolve().as_posix())
        assert os.path.exists(os.path.join(output_dir, "dataset_info.json"))
        assert os.path.exists(os.path.join(output_dir, f"{builder.dataset_name}-train.arrow"))
        assert not os.path.exists(os.path.join(output_dir + ".incomplete"))

def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs):
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
    assert builder._output_dir.startswith("mock://my_dataset")
    assert is_local_path(builder._cache_downloaded_dir)
    assert isinstance(builder._fs, type(mockfs))
    assert builder._fs.storage_options == mockfs.storage_options
    assert mockfs.exists("my_dataset/dataset_info.json")
    assert mockfs.exists(f"my_dataset/{builder.dataset_name}-train.arrow")
    assert not mockfs.exists("my_dataset.incomplete")

def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, caplog):
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
    mockfs.makedirs("my_dataset")
    DatasetInfo().write_to_directory("mock://my_dataset", storage_options=mockfs.storage_options)
    mockfs.touch(f"my_dataset/{builder.dataset_name}-train.arrow")

    with caplog.at_level(INFO, logger=get_logger().name):
        builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
    assert "Found cached dataset" in caplog.text

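# Preparing with file_format="parquet" writes the shards as parquet files; the
# shard count is driven by MAX_SHARD_SIZE / max_shard_size, or by num_proc
# (one shard per worker).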
def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet")
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
    )
    assert os.path.exists(parquet_path)
    assert pq.ParquetFile(parquet_path) is not None

def test_generator_based_builder_download_and_prepare_sharded(tmp_path):
    writer_batch_size = 25
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size)
    with patch("datasets.config.MAX_SHARD_SIZE", 1):
        builder.download_and_prepare(file_format="parquet")
    expected_num_shards = 100 // writer_batch_size
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100

def test_generator_based_builder_download_and_prepare_with_max_shard_size(tmp_path):
    writer_batch_size = 25
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)
    expected_num_shards = 100 // writer_batch_size
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100

def test_generator_based_builder_download_and_prepare_with_num_proc(tmp_path):
    builder = DummyGeneratorBasedBuilderWithShards(cache_dir=tmp_path)
    builder.download_and_prepare(num_proc=2)
    expected_num_shards = 2
    assert builder.info.splits["train"].num_examples == 400
    assert builder.info.splits["train"].shard_lengths == [200, 200]
    arrow_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.arrow",
    )
    assert os.path.exists(arrow_path)
    ds = builder.as_dataset("train")
    assert len(ds) == 400
    assert ds.to_dict() == {
        "id": [i for _ in range(4) for i in range(100)],
        "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)],
    }

@pytest.mark.parametrize(
    "num_proc, expectation", [(None, does_not_raise()), (1, does_not_raise()), (2, pytest.raises(RuntimeError))]
)
def test_generator_based_builder_download_and_prepare_with_ambiguous_shards(num_proc, expectation, tmp_path):
    builder = DummyGeneratorBasedBuilderWithAmbiguousShards(cache_dir=tmp_path)
    with expectation:
        builder.download_and_prepare(num_proc=num_proc)

def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet")
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
    )
    assert os.path.exists(parquet_path)
    assert pq.ParquetFile(parquet_path) is not None

def test_arrow_based_builder_download_and_prepare_sharded(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    with patch("datasets.config.MAX_SHARD_SIZE", 1):
        builder.download_and_prepare(file_format="parquet")
    expected_num_shards = 10
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100

def test_arrow_based_builder_download_and_prepare_with_max_shard_size(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)
    expected_num_shards = 10
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100

def test_arrow_based_builder_download_and_prepare_with_num_proc(tmp_path):
    builder = DummyArrowBasedBuilderWithShards(cache_dir=tmp_path)
    builder.download_and_prepare(num_proc=2)
    expected_num_shards = 2
    assert builder.info.splits["train"].num_examples == 400
    assert builder.info.splits["train"].shard_lengths == [200, 200]
    arrow_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.arrow",
    )
    assert os.path.exists(arrow_path)
    ds = builder.as_dataset("train")
    assert len(ds) == 400
    assert ds.to_dict() == {
        "id": [i for _ in range(4) for i in range(100)],
        "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)],
    }

@pytest.mark.parametrize(
    "num_proc, expectation", [(None, does_not_raise()), (1, does_not_raise()), (2, pytest.raises(RuntimeError))]
)
def test_arrow_based_builder_download_and_prepare_with_ambiguous_shards(num_proc, expectation, tmp_path):
    builder = DummyArrowBasedBuilderWithAmbiguousShards(cache_dir=tmp_path)
    with expectation:
        builder.download_and_prepare(num_proc=num_proc)

@require_beam
def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
    builder.download_and_prepare(file_format="parquet")
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
    )
    assert os.path.exists(parquet_path)
    assert pq.ParquetFile(parquet_path) is not None