test_builder.py
1
import importlib
2
import os
3
import tempfile
4
import types
5
from contextlib import nullcontext as does_not_raise
6
from multiprocessing import Process
7
from pathlib import Path
8
from unittest import TestCase
9
from unittest.mock import patch
10

11
import numpy as np
12
import pyarrow as pa
13
import pyarrow.parquet as pq
14
import pytest
15
from multiprocess.pool import Pool
16

17
from datasets.arrow_dataset import Dataset
18
from datasets.arrow_reader import DatasetNotOnHfGcsError
19
from datasets.arrow_writer import ArrowWriter
20
from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
21
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
22
from datasets.download.download_manager import DownloadMode
23
from datasets.features import Features, Value
24
from datasets.info import DatasetInfo, PostProcessedInfo
25
from datasets.iterable_dataset import IterableDataset
26
from datasets.load import configure_builder_class
27
from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
28
from datasets.streaming import xjoin
29
from datasets.utils.file_utils import is_local_path
30
from datasets.utils.info_utils import VerificationMode
31
from datasets.utils.logging import INFO, get_logger
32

33
from .utils import (
34
    assert_arrow_memory_doesnt_increase,
35
    assert_arrow_memory_increases,
36
    require_beam,
37
    require_faiss,
38
    set_current_working_directory_to_temp_dir,
39
)
40

41

42
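# Minimal DatasetBuilder: a single "train" split of 100 {"text": "foo"} rows written directly with ArrowWriter.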
class DummyBuilder(DatasetBuilder):
43
    def _info(self):
44
        return DatasetInfo(features=Features({"text": Value("string")}))
45

46
    def _split_generators(self, dl_manager):
47
        return [SplitGenerator(name=Split.TRAIN)]
48

49
    def _prepare_split(self, split_generator, **kwargs):
50
        fname = f"{self.dataset_name}-{split_generator.name}.arrow"
51
        with ArrowWriter(features=self.info.features, path=os.path.join(self._output_dir, fname)) as writer:
52
            writer.write_batch({"text": ["foo"] * 100})
53
            num_examples, num_bytes = writer.finalize()
54
        split_generator.split_info.num_examples = num_examples
55
        split_generator.split_info.num_bytes = num_bytes
56

57

58
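# Same data as DummyBuilder, but produced through the GeneratorBasedBuilder API (one example at a time).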
class DummyGeneratorBasedBuilder(GeneratorBasedBuilder):
59
    def _info(self):
60
        return DatasetInfo(features=Features({"text": Value("string")}))
61

62
    def _split_generators(self, dl_manager):
63
        return [SplitGenerator(name=Split.TRAIN)]
64

65
    def _generate_examples(self):
66
        for i in range(100):
67
            yield i, {"text": "foo"}
68

69

70
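# ArrowBasedBuilder variant: yields 10 pyarrow tables of 10 rows each, 100 rows in total.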
class DummyArrowBasedBuilder(ArrowBasedBuilder):
71
    def _info(self):
72
        return DatasetInfo(features=Features({"text": Value("string")}))
73

74
    def _split_generators(self, dl_manager):
75
        return [SplitGenerator(name=Split.TRAIN)]
76

77
    def _generate_tables(self):
78
        for i in range(10):
79
            yield i, pa.table({"text": ["foo"] * 10})
80

81

82
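# BeamBasedBuilder variant: builds the split as an Apache Beam pipeline (apache_beam is imported lazily).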
class DummyBeamBasedBuilder(BeamBasedBuilder):
83
    def _info(self):
84
        return DatasetInfo(features=Features({"text": Value("string")}))
85

86
    def _split_generators(self, dl_manager):
87
        return [SplitGenerator(name=Split.TRAIN)]
88

89
    def _build_pcollection(self, pipeline):
90
        import apache_beam as beam
91

92
        def _process(item):
93
            for i in range(10):
94
                yield f"{i}_{item}", {"text": "foo"}
95

96
        return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process)
97

98

99
class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder):
100
    def _info(self):
101
        return DatasetInfo(features=Features({"id": Value("int8")}))
102

103
    def _split_generators(self, dl_manager):
104
        return [SplitGenerator(name=Split.TRAIN)]
105

106
    def _generate_examples(self):
107
        for i in range(100):
108
            yield i, {"id": i}
109

110

111
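# Custom BuilderConfig with two extra kwargs (content, times). The cache-dir tests below check that these
# kwargs are reflected in the builder's cache directory, e.g. a path containing "content=foo" and "times=2".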
class DummyGeneratorBasedBuilderConfig(BuilderConfig):
112
    def __init__(self, content="foo", times=2, *args, **kwargs):
113
        super().__init__(*args, **kwargs)
114
        self.content = content
115
        self.times = times
116

117

118
class DummyGeneratorBasedBuilderWithConfig(GeneratorBasedBuilder):
119
    BUILDER_CONFIG_CLASS = DummyGeneratorBasedBuilderConfig
120

121
    def _info(self):
122
        return DatasetInfo(features=Features({"text": Value("string")}))
123

124
    def _split_generators(self, dl_manager):
125
        return [SplitGenerator(name=Split.TRAIN)]
126

127
    def _generate_examples(self):
128
        for i in range(100):
129
            yield i, {"text": self.config.content * self.config.times}
130

131

132
class DummyBuilderWithMultipleConfigs(DummyBuilder):
133
    BUILDER_CONFIGS = [
134
        DummyGeneratorBasedBuilderConfig(name="a"),
135
        DummyGeneratorBasedBuilderConfig(name="b"),
136
    ]
137

138

139
class DummyBuilderWithDefaultConfig(DummyBuilderWithMultipleConfigs):
140
    DEFAULT_CONFIG_NAME = "a"
141

142

143
class DummyBuilderWithDownload(DummyBuilder):
144
    def __init__(self, *args, rel_path=None, abs_path=None, **kwargs):
145
        super().__init__(*args, **kwargs)
146
        self._rel_path = rel_path
147
        self._abs_path = abs_path
148

149
    def _split_generators(self, dl_manager):
150
        if self._rel_path is not None:
151
            assert os.path.exists(dl_manager.download(self._rel_path)), "dl_manager must support relative paths"
152
        if self._abs_path is not None:
153
            assert os.path.exists(dl_manager.download(self._abs_path)), "dl_manager must support absolute paths"
154
        return [SplitGenerator(name=Split.TRAIN)]
155

156

157
class DummyBuilderWithManualDownload(DummyBuilderWithMultipleConfigs):
158
    @property
159
    def manual_download_instructions(self):
160
        return "To use the dataset you have to download some stuff manually and pass the data path to data_dir"
161

162
    def _split_generators(self, dl_manager):
163
        if not os.path.exists(self.config.data_dir):
164
            raise FileNotFoundError(f"data_dir {self.config.data_dir} doesn't exist.")
165
        return [SplitGenerator(name=Split.TRAIN)]
166

167

168
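# Shard-aware builders: the list-valued "filepaths" gen_kwarg lets download_and_prepare(num_proc=...)
# distribute the files across worker processes (4 files x 100 rows = 400 rows in the train split).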
class DummyArrowBasedBuilderWithShards(ArrowBasedBuilder):
169
    def _info(self):
170
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))
171

172
    def _split_generators(self, dl_manager):
173
        return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})]
174

175
    def _generate_tables(self, filepaths):
176
        idx = 0
177
        for filepath in filepaths:
178
            for i in range(10):
179
                yield idx, pa.table({"id": range(10 * i, 10 * (i + 1)), "filepath": [filepath] * 10})
180
                idx += 1
181

182

183
class DummyGeneratorBasedBuilderWithShards(GeneratorBasedBuilder):
184
    def _info(self):
185
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))
186

187
    def _split_generators(self, dl_manager):
188
        return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})]
189

190
    def _generate_examples(self, filepaths):
191
        idx = 0
192
        for filepath in filepaths:
193
            for i in range(100):
194
                yield idx, {"id": i, "filepath": filepath}
195
                idx += 1
196

197

198
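# "Ambiguous" shard builders: two list-valued gen_kwargs of different lengths (4 vs 3), so the split cannot
# be partitioned consistently across processes; preparing with num_proc > 1 is expected to raise a RuntimeError.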
class DummyArrowBasedBuilderWithAmbiguousShards(ArrowBasedBuilder):
199
    def _info(self):
200
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))
201

202
    def _split_generators(self, dl_manager):
203
        return [
204
            SplitGenerator(
205
                name=Split.TRAIN,
206
                gen_kwargs={
207
                    "filepaths": [f"data{i}.txt" for i in range(4)],
208
                    "dummy_kwarg_with_different_length": [f"dummy_data{i}.txt" for i in range(3)],
209
                },
210
            )
211
        ]
212

213
    def _generate_tables(self, filepaths, dummy_kwarg_with_different_length):
214
        idx = 0
215
        for filepath in filepaths:
216
            for i in range(10):
217
                yield idx, pa.table({"id": range(10 * i, 10 * (i + 1)), "filepath": [filepath] * 10})
218
                idx += 1
219

220

221
class DummyGeneratorBasedBuilderWithAmbiguousShards(GeneratorBasedBuilder):
222
    def _info(self):
223
        return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")}))
224

225
    def _split_generators(self, dl_manager):
226
        return [
227
            SplitGenerator(
228
                name=Split.TRAIN,
229
                gen_kwargs={
230
                    "filepaths": [f"data{i}.txt" for i in range(4)],
231
                    "dummy_kwarg_with_different_length": [f"dummy_data{i}.txt" for i in range(3)],
232
                },
233
            )
234
        ]
235

236
    def _generate_examples(self, filepaths, dummy_kwarg_with_different_length):
237
        idx = 0
238
        for filepath in filepaths:
239
            for i in range(100):
240
                yield idx, {"id": i, "filepath": filepath}
241
                idx += 1
242

243

244
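# Helper run in worker processes below: REUSE_DATASET_IF_EXISTS lets concurrent callers share one prepared cache.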
def _run_concurrent_download_and_prepare(tmp_dir):
245
    builder = DummyBuilder(cache_dir=tmp_dir)
246
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
247
    return builder
248

249

250
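# Streaming patches the builder's module on the fly: os.path.join is swapped for datasets.streaming.xjoin
# so local paths and remote URLs can be joined the same way.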
def check_streaming(builder):
251
    builders_module = importlib.import_module(builder.__module__)
252
    assert builders_module._patched_for_streaming
253
    assert builders_module.os.path.join is xjoin
254

255

256
class BuilderTest(TestCase):
257
    def test_download_and_prepare(self):
258
        with tempfile.TemporaryDirectory() as tmp_dir:
259
            builder = DummyBuilder(cache_dir=tmp_dir)
260
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
261
            self.assertTrue(
262
                os.path.exists(
263
                    os.path.join(
264
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
265
                    )
266
                )
267
            )
268
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
269
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
270
            self.assertTrue(
271
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
272
            )
273

274
    def test_download_and_prepare_checksum_computation(self):
275
        with tempfile.TemporaryDirectory() as tmp_dir:
276
            builder_no_verification = DummyBuilder(cache_dir=tmp_dir)
277
            builder_no_verification.download_and_prepare(
278
                try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD
279
            )
280
            self.assertTrue(
281
                all(v["checksum"] is not None for _, v in builder_no_verification.info.download_checksums.items())
282
            )
283
            builder_with_verification = DummyBuilder(cache_dir=tmp_dir)
284
            builder_with_verification.download_and_prepare(
285
                try_from_hf_gcs=False,
286
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
287
                verification_mode=VerificationMode.ALL_CHECKS,
288
            )
289
            self.assertTrue(
290
                all(v["checksum"] is None for _, v in builder_with_verification.info.download_checksums.items())
291
            )
292

293
    def test_concurrent_download_and_prepare(self):
294
        with tempfile.TemporaryDirectory() as tmp_dir:
295
            processes = 2
296
            with Pool(processes=processes) as pool:
297
                jobs = [
298
                    pool.apply_async(_run_concurrent_download_and_prepare, kwds={"tmp_dir": tmp_dir})
299
                    for _ in range(processes)
300
                ]
301
                builders = [job.get() for job in jobs]
302
                for builder in builders:
303
                    self.assertTrue(
304
                        os.path.exists(
305
                            os.path.join(
306
                                tmp_dir,
307
                                builder.dataset_name,
308
                                "default",
309
                                "0.0.0",
310
                                f"{builder.dataset_name}-train.arrow",
311
                            )
312
                        )
313
                    )
314
                    self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
315
                    self.assertEqual(builder.info.splits["train"].num_examples, 100)
316
                    self.assertTrue(
317
                        os.path.exists(
318
                            os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json")
319
                        )
320
                    )
321

322
    def test_download_and_prepare_with_base_path(self):
323
        with tempfile.TemporaryDirectory() as tmp_dir:
324
            rel_path = "dummy1.data"
325
            abs_path = os.path.join(tmp_dir, "dummy2.data")
326
            # test relative path is missing
327
            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, rel_path=rel_path)
328
            with self.assertRaises(FileNotFoundError):
329
                builder.download_and_prepare(
330
                    try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
331
                )
332
            # test absolute path is missing
333
            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, abs_path=abs_path)
334
            with self.assertRaises(FileNotFoundError):
335
                builder.download_and_prepare(
336
                    try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
337
                )
338
            # test that they are both properly loaded when they exist
339
            open(os.path.join(tmp_dir, rel_path), "w").close()
340
            open(abs_path, "w").close()
341
            builder = DummyBuilderWithDownload(cache_dir=tmp_dir, rel_path=rel_path, abs_path=abs_path)
342
            builder.download_and_prepare(
343
                try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD, base_path=tmp_dir
344
            )
345
            self.assertTrue(
346
                os.path.exists(
347
                    os.path.join(
348
                        tmp_dir,
349
                        builder.dataset_name,
350
                        "default",
351
                        "0.0.0",
352
                        f"{builder.dataset_name}-train.arrow",
353
                    )
354
                )
355
            )
356

357
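    # The _post_process / _post_processing_resources hooks are monkey-patched onto the builder instance with
    # types.MethodType to simulate datasets that add extra columns or a FAISS index after preparation.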
    def test_as_dataset_with_post_process(self):
358
        def _post_process(self, dataset, resources_paths):
359
            def char_tokenize(example):
360
                return {"tokens": list(example["text"])}
361

362
            return dataset.map(char_tokenize, cache_file_name=resources_paths["tokenized_dataset"])
363

364
        def _post_processing_resources(self, split):
365
            return {"tokenized_dataset": f"tokenized_dataset-{split}.arrow"}
366

367
        with tempfile.TemporaryDirectory() as tmp_dir:
368
            builder = DummyBuilder(cache_dir=tmp_dir)
369
            builder.info.post_processed = PostProcessedInfo(
370
                features=Features({"text": Value("string"), "tokens": [Value("string")]})
371
            )
372
            builder._post_process = types.MethodType(_post_process, builder)
373
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
374
            os.makedirs(builder.cache_dir)
375

376
            builder.info.splits = SplitDict()
377
            builder.info.splits.add(SplitInfo("train", num_examples=10))
378
            builder.info.splits.add(SplitInfo("test", num_examples=10))
379

380
            for split in builder.info.splits:
381
                with ArrowWriter(
382
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
383
                    features=Features({"text": Value("string")}),
384
                ) as writer:
385
                    writer.write_batch({"text": ["foo"] * 10})
386
                    writer.finalize()
387

388
                with ArrowWriter(
389
                    path=os.path.join(builder.cache_dir, f"tokenized_dataset-{split}.arrow"),
390
                    features=Features({"text": Value("string"), "tokens": [Value("string")]}),
391
                ) as writer:
392
                    writer.write_batch({"text": ["foo"] * 10, "tokens": [list("foo")] * 10})
393
                    writer.finalize()
394

395
            dsets = builder.as_dataset()
396
            self.assertIsInstance(dsets, DatasetDict)
397
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
398
            self.assertEqual(len(dsets["train"]), 10)
399
            self.assertEqual(len(dsets["test"]), 10)
400
            self.assertDictEqual(
401
                dsets["train"].features, Features({"text": Value("string"), "tokens": [Value("string")]})
402
            )
403
            self.assertDictEqual(
404
                dsets["test"].features, Features({"text": Value("string"), "tokens": [Value("string")]})
405
            )
406
            self.assertListEqual(dsets["train"].column_names, ["text", "tokens"])
407
            self.assertListEqual(dsets["test"].column_names, ["text", "tokens"])
408
            del dsets
409

410
            dset = builder.as_dataset("train")
411
            self.assertIsInstance(dset, Dataset)
412
            self.assertEqual(dset.split, "train")
413
            self.assertEqual(len(dset), 10)
414
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
415
            self.assertListEqual(dset.column_names, ["text", "tokens"])
416
            self.assertGreater(builder.info.post_processing_size, 0)
417
            self.assertGreater(
418
                builder.info.post_processed.resources_checksums["train"]["tokenized_dataset"]["num_bytes"], 0
419
            )
420
            del dset
421

422
            dset = builder.as_dataset("train+test[:30%]")
423
            self.assertIsInstance(dset, Dataset)
424
            self.assertEqual(dset.split, "train+test[:30%]")
425
            self.assertEqual(len(dset), 13)
426
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
427
            self.assertListEqual(dset.column_names, ["text", "tokens"])
428
            del dset
429

430
            dset = builder.as_dataset("all")
431
            self.assertIsInstance(dset, Dataset)
432
            self.assertEqual(dset.split, "train+test")
433
            self.assertEqual(len(dset), 20)
434
            self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
435
            self.assertListEqual(dset.column_names, ["text", "tokens"])
436
            del dset
437

438
        def _post_process(self, dataset, resources_paths):
439
            return dataset.select([0, 1], keep_in_memory=True)
440

441
        with tempfile.TemporaryDirectory() as tmp_dir:
442
            builder = DummyBuilder(cache_dir=tmp_dir)
443
            builder._post_process = types.MethodType(_post_process, builder)
444
            os.makedirs(builder.cache_dir)
445

446
            builder.info.splits = SplitDict()
447
            builder.info.splits.add(SplitInfo("train", num_examples=10))
448
            builder.info.splits.add(SplitInfo("test", num_examples=10))
449

450
            for split in builder.info.splits:
451
                with ArrowWriter(
452
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
453
                    features=Features({"text": Value("string")}),
454
                ) as writer:
455
                    writer.write_batch({"text": ["foo"] * 10})
456
                    writer.finalize()
457

458
                with ArrowWriter(
459
                    path=os.path.join(builder.cache_dir, f"small_dataset-{split}.arrow"),
460
                    features=Features({"text": Value("string")}),
461
                ) as writer:
462
                    writer.write_batch({"text": ["foo"] * 2})
463
                    writer.finalize()
464

465
            dsets = builder.as_dataset()
466
            self.assertIsInstance(dsets, DatasetDict)
467
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
468
            self.assertEqual(len(dsets["train"]), 2)
469
            self.assertEqual(len(dsets["test"]), 2)
470
            self.assertDictEqual(dsets["train"].features, Features({"text": Value("string")}))
471
            self.assertDictEqual(dsets["test"].features, Features({"text": Value("string")}))
472
            self.assertListEqual(dsets["train"].column_names, ["text"])
473
            self.assertListEqual(dsets["test"].column_names, ["text"])
474
            del dsets
475

476
            dset = builder.as_dataset("train")
477
            self.assertIsInstance(dset, Dataset)
478
            self.assertEqual(dset.split, "train")
479
            self.assertEqual(len(dset), 2)
480
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
481
            self.assertListEqual(dset.column_names, ["text"])
482
            del dset
483

484
            dset = builder.as_dataset("train+test[:30%]")
485
            self.assertIsInstance(dset, Dataset)
486
            self.assertEqual(dset.split, "train+test[:30%]")
487
            self.assertEqual(len(dset), 2)
488
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
489
            self.assertListEqual(dset.column_names, ["text"])
490
            del dset
491

492
    @require_faiss
493
    def test_as_dataset_with_post_process_with_index(self):
494
        def _post_process(self, dataset, resources_paths):
495
            if os.path.exists(resources_paths["index"]):
496
                dataset.load_faiss_index("my_index", resources_paths["index"])
497
                return dataset
498
            else:
499
                dataset.add_faiss_index_from_external_arrays(
500
                    external_arrays=np.ones((len(dataset), 8)), string_factory="Flat", index_name="my_index"
501
                )
502
                dataset.save_faiss_index("my_index", resources_paths["index"])
503
                return dataset
504

505
        def _post_processing_resources(self, split):
506
            return {"index": f"Flat-{split}.faiss"}
507

508
        with tempfile.TemporaryDirectory() as tmp_dir:
509
            builder = DummyBuilder(cache_dir=tmp_dir)
510
            builder._post_process = types.MethodType(_post_process, builder)
511
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
512
            os.makedirs(builder.cache_dir)
513

514
            builder.info.splits = SplitDict()
515
            builder.info.splits.add(SplitInfo("train", num_examples=10))
516
            builder.info.splits.add(SplitInfo("test", num_examples=10))
517

518
            for split in builder.info.splits:
519
                with ArrowWriter(
520
                    path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split}.arrow"),
521
                    features=Features({"text": Value("string")}),
522
                ) as writer:
523
                    writer.write_batch({"text": ["foo"] * 10})
524
                    writer.finalize()
525

526
                with ArrowWriter(
527
                    path=os.path.join(builder.cache_dir, f"small_dataset-{split}.arrow"),
528
                    features=Features({"text": Value("string")}),
529
                ) as writer:
530
                    writer.write_batch({"text": ["foo"] * 2})
531
                    writer.finalize()
532

533
            dsets = builder.as_dataset()
534
            self.assertIsInstance(dsets, DatasetDict)
535
            self.assertListEqual(list(dsets.keys()), ["train", "test"])
536
            self.assertEqual(len(dsets["train"]), 10)
537
            self.assertEqual(len(dsets["test"]), 10)
538
            self.assertDictEqual(dsets["train"].features, Features({"text": Value("string")}))
539
            self.assertDictEqual(dsets["test"].features, Features({"text": Value("string")}))
540
            self.assertListEqual(dsets["train"].column_names, ["text"])
541
            self.assertListEqual(dsets["test"].column_names, ["text"])
542
            self.assertListEqual(dsets["train"].list_indexes(), ["my_index"])
543
            self.assertListEqual(dsets["test"].list_indexes(), ["my_index"])
544
            self.assertGreater(builder.info.post_processing_size, 0)
545
            self.assertGreater(builder.info.post_processed.resources_checksums["train"]["index"]["num_bytes"], 0)
546
            del dsets
547

548
            dset = builder.as_dataset("train")
549
            self.assertIsInstance(dset, Dataset)
550
            self.assertEqual(dset.split, "train")
551
            self.assertEqual(len(dset), 10)
552
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
553
            self.assertListEqual(dset.column_names, ["text"])
554
            self.assertListEqual(dset.list_indexes(), ["my_index"])
555
            del dset
556

557
            dset = builder.as_dataset("train+test[:30%]")
558
            self.assertIsInstance(dset, Dataset)
559
            self.assertEqual(dset.split, "train+test[:30%]")
560
            self.assertEqual(len(dset), 13)
561
            self.assertDictEqual(dset.features, Features({"text": Value("string")}))
562
            self.assertListEqual(dset.column_names, ["text"])
563
            self.assertListEqual(dset.list_indexes(), ["my_index"])
564
            del dset
565

566
    def test_download_and_prepare_with_post_process(self):
567
        def _post_process(self, dataset, resources_paths):
568
            def char_tokenize(example):
569
                return {"tokens": list(example["text"])}
570

571
            return dataset.map(char_tokenize, cache_file_name=resources_paths["tokenized_dataset"])
572

573
        def _post_processing_resources(self, split):
574
            return {"tokenized_dataset": f"tokenized_dataset-{split}.arrow"}
575

576
        with tempfile.TemporaryDirectory() as tmp_dir:
577
            builder = DummyBuilder(cache_dir=tmp_dir)
578
            builder.info.post_processed = PostProcessedInfo(
579
                features=Features({"text": Value("string"), "tokens": [Value("string")]})
580
            )
581
            builder._post_process = types.MethodType(_post_process, builder)
582
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
583
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
584
            self.assertTrue(
585
                os.path.exists(
586
                    os.path.join(
587
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
588
                    )
589
                )
590
            )
591
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
592
            self.assertDictEqual(
593
                builder.info.post_processed.features,
594
                Features({"text": Value("string"), "tokens": [Value("string")]}),
595
            )
596
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
597
            self.assertTrue(
598
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
599
            )
600

601
        def _post_process(self, dataset, resources_paths):
602
            return dataset.select([0, 1], keep_in_memory=True)
603

604
        with tempfile.TemporaryDirectory() as tmp_dir:
605
            builder = DummyBuilder(cache_dir=tmp_dir)
606
            builder._post_process = types.MethodType(_post_process, builder)
607
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
608
            self.assertTrue(
609
                os.path.exists(
610
                    os.path.join(
611
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
612
                    )
613
                )
614
            )
615
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
616
            self.assertIsNone(builder.info.post_processed)
617
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
618
            self.assertTrue(
619
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
620
            )
621

622
        def _post_process(self, dataset, resources_paths):
623
            if os.path.exists(resources_paths["index"]):
624
                dataset.load_faiss_index("my_index", resources_paths["index"])
625
                return dataset
626
            else:
627
                dataset = dataset.add_faiss_index_from_external_arrays(
628
                    external_arrays=np.ones((len(dataset), 8)), string_factory="Flat", index_name="my_index"
629
                )
630
                dataset.save_faiss_index("my_index", resources_paths["index"])
631
                return dataset
632

633
        def _post_processing_resources(self, split):
634
            return {"index": f"Flat-{split}.faiss"}
635

636
        with tempfile.TemporaryDirectory() as tmp_dir:
637
            builder = DummyBuilder(cache_dir=tmp_dir)
638
            builder._post_process = types.MethodType(_post_process, builder)
639
            builder._post_processing_resources = types.MethodType(_post_processing_resources, builder)
640
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
641
            self.assertTrue(
642
                os.path.exists(
643
                    os.path.join(
644
                        tmp_dir, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.arrow"
645
                    )
646
                )
647
            )
648
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
649
            self.assertIsNone(builder.info.post_processed)
650
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
651
            self.assertTrue(
652
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
653
            )
654

655
    def test_error_download_and_prepare(self):
656
        def _prepare_split(self, split_generator, **kwargs):
657
            raise ValueError()
658

659
        with tempfile.TemporaryDirectory() as tmp_dir:
660
            builder = DummyBuilder(cache_dir=tmp_dir)
661
            builder._prepare_split = types.MethodType(_prepare_split, builder)
662
            self.assertRaises(
663
                ValueError,
664
                builder.download_and_prepare,
665
                try_from_hf_gcs=False,
666
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
667
            )
668
            self.assertRaises(FileNotFoundError, builder.as_dataset)
669

670
    def test_generator_based_download_and_prepare(self):
671
        with tempfile.TemporaryDirectory() as tmp_dir:
672
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir)
673
            builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
674
            self.assertTrue(
675
                os.path.exists(
676
                    os.path.join(
677
                        tmp_dir,
678
                        builder.dataset_name,
679
                        "default",
680
                        "0.0.0",
681
                        f"{builder.dataset_name}-train.arrow",
682
                    )
683
                )
684
            )
685
            self.assertDictEqual(builder.info.features, Features({"text": Value("string")}))
686
            self.assertEqual(builder.info.splits["train"].num_examples, 100)
687
            self.assertTrue(
688
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
689
            )
690

691
        # Test that duplicated keys are ignored if verification_mode is "no_checks"
692
        with tempfile.TemporaryDirectory() as tmp_dir:
693
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir)
694
            with patch("datasets.builder.ArrowWriter", side_effect=ArrowWriter) as mock_arrow_writer:
695
                builder.download_and_prepare(
696
                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.NO_CHECKS
697
                )
698
                mock_arrow_writer.assert_called_once()
699
                args, kwargs = mock_arrow_writer.call_args_list[0]
700
                self.assertFalse(kwargs["check_duplicates"])
701

702
                mock_arrow_writer.reset_mock()
703

704
                builder.download_and_prepare(
705
                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.BASIC_CHECKS
706
                )
707
                mock_arrow_writer.assert_called_once()
708
                args, kwargs = mock_arrow_writer.call_args_list[0]
709
                self.assertTrue(kwargs["check_duplicates"])
710

711
    def test_cache_dir_no_args(self):
712
        with tempfile.TemporaryDirectory() as tmp_dir:
713
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_dir=None, data_files=None)
714
            relative_cache_dir_parts = Path(builder._relative_data_dir()).parts
715
            self.assertTupleEqual(relative_cache_dir_parts, (builder.dataset_name, "default", "0.0.0"))
716

717
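    # The cache directory encodes the data_files layout: the same files under the same split names (and in the
    # same order) reuse one cache directory, while any other combination gets a fresh one.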
    def test_cache_dir_for_data_files(self):
718
        with tempfile.TemporaryDirectory() as tmp_dir:
719
            dummy_data1 = os.path.join(tmp_dir, "dummy_data1.txt")
720
            with open(dummy_data1, "w", encoding="utf-8") as f:
721
                f.write("foo bar")
722
            dummy_data2 = os.path.join(tmp_dir, "dummy_data2.txt")
723
            with open(dummy_data2, "w", encoding="utf-8") as f:
724
                f.write("foo bar\n")
725

726
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data1)
727
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data1)
728
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
729
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1])
730
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
731
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"train": dummy_data1})
732
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
733
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={Split.TRAIN: dummy_data1})
734
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
735
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"train": [dummy_data1]})
736
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
737
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files={"test": dummy_data1})
738
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
739
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=dummy_data2)
740
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
741
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data2])
742
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
743
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
744
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
745

746
            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
747
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data1, dummy_data2])
748
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
749
            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_files=[dummy_data2, dummy_data1])
750
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
751

752
            builder = DummyGeneratorBasedBuilder(
753
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "test": dummy_data2}
754
            )
755
            other_builder = DummyGeneratorBasedBuilder(
756
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "test": dummy_data2}
757
            )
758
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
759
            other_builder = DummyGeneratorBasedBuilder(
760
                cache_dir=tmp_dir, data_files={"train": [dummy_data1], "test": dummy_data2}
761
            )
762
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
763
            other_builder = DummyGeneratorBasedBuilder(
764
                cache_dir=tmp_dir, data_files={"train": dummy_data1, "validation": dummy_data2}
765
            )
766
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
767
            other_builder = DummyGeneratorBasedBuilder(
768
                cache_dir=tmp_dir,
769
                data_files={"train": [dummy_data1, dummy_data2], "test": dummy_data2},
770
            )
771
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
772

773
    def test_cache_dir_for_features(self):
774
        with tempfile.TemporaryDirectory() as tmp_dir:
775
            f1 = Features({"id": Value("int8")})
776
            f2 = Features({"id": Value("int32")})
777
            builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f1)
778
            other_builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f1)
779
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
780
            other_builder = DummyGeneratorBasedBuilderWithIntegers(cache_dir=tmp_dir, features=f2)
781
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
782

783
    def test_cache_dir_for_config_kwargs(self):
784
        with tempfile.TemporaryDirectory() as tmp_dir:
785
            # create config on the fly
786
            builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="foo", times=2)
787
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, times=2, content="foo")
788
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
789
            self.assertIn("content=foo", builder.cache_dir)
790
            self.assertIn("times=2", builder.cache_dir)
791
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="bar", times=2)
792
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
793
            other_builder = DummyGeneratorBasedBuilderWithConfig(cache_dir=tmp_dir, content="foo")
794
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
795

796
        with tempfile.TemporaryDirectory() as tmp_dir:
797
            # overwrite an existing config
798
            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="foo", times=2)
799
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", times=2, content="foo")
800
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
801
            self.assertIn("content=foo", builder.cache_dir)
802
            self.assertIn("times=2", builder.cache_dir)
803
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="bar", times=2)
804
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
805
            other_builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a", content="foo")
806
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
807

808
    def test_config_names(self):
809
        with tempfile.TemporaryDirectory() as tmp_dir:
810
            with self.assertRaises(ValueError) as error_context:
811
                DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, data_files=None, data_dir=None)
812
            self.assertIn("Please pick one among the available configs", str(error_context.exception))
813

814
            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="a")
815
            self.assertEqual(builder.config.name, "a")
816

817
            builder = DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir, config_name="b")
818
            self.assertEqual(builder.config.name, "b")
819

820
            with self.assertRaises(ValueError):
821
                DummyBuilderWithMultipleConfigs(cache_dir=tmp_dir)
822

823
            builder = DummyBuilderWithDefaultConfig(cache_dir=tmp_dir)
824
            self.assertEqual(builder.config.name, "a")
825

826
    def test_cache_dir_for_data_dir(self):
827
        with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
828
            builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
829
            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
830
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
831
            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir)
832
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
833

834
    def test_cache_dir_for_configured_builder(self):
835
        with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
836
            builder_cls = configure_builder_class(
837
                DummyBuilderWithManualDownload,
838
                builder_configs=[BuilderConfig(data_dir=data_dir)],
839
                default_config_name=None,
840
                dataset_name="dummy",
841
            )
842
            builder = builder_cls(cache_dir=tmp_dir, hash="abc")
843
            other_builder = builder_cls(cache_dir=tmp_dir, hash="abc")
844
            self.assertEqual(builder.cache_dir, other_builder.cache_dir)
845
            other_builder = builder_cls(cache_dir=tmp_dir, hash="def")
846
            self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
847

848

849
def test_arrow_based_download_and_prepare(tmp_path):
850
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
851
    builder.download_and_prepare()
852
    assert os.path.exists(
853
        os.path.join(
854
            tmp_path,
855
            builder.dataset_name,
856
            "default",
857
            "0.0.0",
858
            f"{builder.dataset_name}-train.arrow",
859
        )
860
    )
861
    assert builder.info.features == Features({"text": Value("string")})
862
    assert builder.info.splits["train"].num_examples == 100
863
    assert os.path.exists(os.path.join(tmp_path, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
864

865

866
@require_beam
867
def test_beam_based_download_and_prepare(tmp_path):
868
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
869
    builder.download_and_prepare()
870
    assert os.path.exists(
871
        os.path.join(
872
            tmp_path,
873
            builder.dataset_name,
874
            "default",
875
            "0.0.0",
876
            f"{builder.dataset_name}-train.arrow",
877
        )
878
    )
879
    assert builder.info.features == Features({"text": Value("string")})
880
    assert builder.info.splits["train"].num_examples == 100
881
    assert os.path.exists(os.path.join(tmp_path, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
882

883

884
@require_beam
885
def test_beam_based_as_dataset(tmp_path):
886
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
887
    builder.download_and_prepare()
888
    dataset = builder.as_dataset()
889
    assert dataset
890
    assert isinstance(dataset["train"], Dataset)
891
    assert len(dataset["train"]) > 0
892

893

894
@pytest.mark.parametrize(
895
    "split, expected_dataset_class, expected_dataset_length",
896
    [
897
        (None, DatasetDict, 10),
898
        ("train", Dataset, 10),
899
        ("train+test[:30%]", Dataset, 13),
900
    ],
901
)
902
@pytest.mark.parametrize("in_memory", [False, True])
903
def test_builder_as_dataset(split, expected_dataset_class, expected_dataset_length, in_memory, tmp_path):
904
    cache_dir = str(tmp_path)
905
    builder = DummyBuilder(cache_dir=cache_dir)
906
    os.makedirs(builder.cache_dir)
907

908
    builder.info.splits = SplitDict()
909
    builder.info.splits.add(SplitInfo("train", num_examples=10))
910
    builder.info.splits.add(SplitInfo("test", num_examples=10))
911

912
    for info_split in builder.info.splits:
913
        with ArrowWriter(
914
            path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{info_split}.arrow"),
915
            features=Features({"text": Value("string")}),
916
        ) as writer:
917
            writer.write_batch({"text": ["foo"] * 10})
918
            writer.finalize()
919

920
    with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
921
        dataset = builder.as_dataset(split=split, in_memory=in_memory)
922
    assert isinstance(dataset, expected_dataset_class)
923
    if isinstance(dataset, DatasetDict):
924
        assert list(dataset.keys()) == ["train", "test"]
925
        datasets = dataset.values()
926
        expected_splits = ["train", "test"]
927
    elif isinstance(dataset, Dataset):
928
        datasets = [dataset]
929
        expected_splits = [split]
930
    for dataset, expected_split in zip(datasets, expected_splits):
931
        assert dataset.split == expected_split
932
        assert len(dataset) == expected_dataset_length
933
        assert dataset.features == Features({"text": Value("string")})
934
        assert dataset.column_names == ["text"]
935

936

937
@pytest.mark.parametrize("in_memory", [False, True])
938
def test_generator_based_builder_as_dataset(in_memory, tmp_path):
939
    cache_dir = tmp_path / "data"
940
    cache_dir.mkdir()
941
    cache_dir = str(cache_dir)
942
    builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir)
943
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
944
    with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
945
        dataset = builder.as_dataset("train", in_memory=in_memory)
946
    assert dataset.data.to_pydict() == {"text": ["foo"] * 100}
947

948

949
@pytest.mark.parametrize(
950
    "writer_batch_size, default_writer_batch_size, expected_chunks", [(None, None, 1), (None, 5, 20), (10, None, 10)]
951
)
952
def test_custom_writer_batch_size(tmp_path, writer_batch_size, default_writer_batch_size, expected_chunks):
953
    cache_dir = str(tmp_path)
954
    if default_writer_batch_size:
955
        DummyGeneratorBasedBuilder.DEFAULT_WRITER_BATCH_SIZE = default_writer_batch_size
956
    builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir, writer_batch_size=writer_batch_size)
957
    assert builder._writer_batch_size == (writer_batch_size or default_writer_batch_size)
958
    builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
959
    dataset = builder.as_dataset("train")
960
    assert len(dataset.data[0].chunks) == expected_chunks
961

962

963
def test_builder_as_streaming_dataset(tmp_path):
964
    dummy_builder = DummyGeneratorBasedBuilder(cache_dir=str(tmp_path))
965
    check_streaming(dummy_builder)
966
    dsets = dummy_builder.as_streaming_dataset()
967
    assert isinstance(dsets, IterableDatasetDict)
968
    assert isinstance(dsets["train"], IterableDataset)
969
    assert len(list(dsets["train"])) == 100
970
    dset = dummy_builder.as_streaming_dataset(split="train")
971
    assert isinstance(dset, IterableDataset)
972
    assert len(list(dset)) == 100
973

974

975
@require_beam
976
def test_beam_based_builder_as_streaming_dataset(tmp_path):
977
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path)
978
    check_streaming(builder)
979
    with pytest.raises(DatasetNotOnHfGcsError):
980
        builder.as_streaming_dataset()
981

982

983
def _run_test_builder_streaming_works_in_subprocesses(builder):
984
    check_streaming(builder)
985
    dset = builder.as_streaming_dataset(split="train")
986
    assert isinstance(dset, IterableDataset)
987
    assert len(list(dset)) == 100
988

989

990
def test_builder_streaming_works_in_subprocess(tmp_path):
991
    dummy_builder = DummyGeneratorBasedBuilder(cache_dir=str(tmp_path))
992
    p = Process(target=_run_test_builder_streaming_works_in_subprocesses, args=(dummy_builder,))
993
    p.start()
994
    p.join()
995

996

997
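# The three builders below feed test_builder_config_version: the config version may come from the class-level
# VERSION attribute, from BUILDER_CONFIGS, or from a custom BuilderConfig subclass default.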
class DummyBuilderWithVersion(GeneratorBasedBuilder):
998
    VERSION = "2.0.0"
999

1000
    def _info(self):
1001
        return DatasetInfo(features=Features({"text": Value("string")}))
1002

1003
    def _split_generators(self, dl_manager):
1004
        pass
1005

1006
    def _generate_examples(self):
1007
        pass
1008

1009

1010
class DummyBuilderWithBuilderConfigs(GeneratorBasedBuilder):
1011
    BUILDER_CONFIGS = [BuilderConfig(name="custom", version="2.0.0")]
1012

1013
    def _info(self):
1014
        return DatasetInfo(features=Features({"text": Value("string")}))
1015

1016
    def _split_generators(self, dl_manager):
1017
        pass
1018

1019
    def _generate_examples(self):
1020
        pass
1021

1022

1023
class CustomBuilderConfig(BuilderConfig):
1024
    def __init__(self, date=None, language=None, version="2.0.0", **kwargs):
1025
        name = f"{date}.{language}"
1026
        super().__init__(name=name, version=version, **kwargs)
1027
        self.date = date
1028
        self.language = language
1029

1030

1031
class DummyBuilderWithCustomBuilderConfigs(GeneratorBasedBuilder):
1032
    BUILDER_CONFIGS = [CustomBuilderConfig(date="20220501", language="en")]
1033
    BUILDER_CONFIG_CLASS = CustomBuilderConfig
1034

1035
    def _info(self):
1036
        return DatasetInfo(features=Features({"text": Value("string")}))
1037

1038
    def _split_generators(self, dl_manager):
1039
        pass
1040

1041
    def _generate_examples(self):
1042
        pass
1043

1044

1045
@pytest.mark.parametrize(
1046
    "builder_class, kwargs",
1047
    [
1048
        (DummyBuilderWithVersion, {}),
1049
        (DummyBuilderWithBuilderConfigs, {"config_name": "custom"}),
1050
        (DummyBuilderWithCustomBuilderConfigs, {"config_name": "20220501.en"}),
1051
        (DummyBuilderWithCustomBuilderConfigs, {"date": "20220501", "language": "ca"}),
1052
    ],
1053
)
1054
def test_builder_config_version(builder_class, kwargs, tmp_path):
1055
    cache_dir = str(tmp_path)
1056
    builder = builder_class(cache_dir=cache_dir, **kwargs)
1057
    assert builder.config.version == "2.0.0"
1058

1059

1060
def test_builder_download_and_prepare_with_absolute_output_dir(tmp_path):
1061
    builder = DummyGeneratorBasedBuilder()
1062
    output_dir = str(tmp_path)
1063
    builder.download_and_prepare(output_dir)
1064
    assert builder._output_dir.startswith(tmp_path.resolve().as_posix())
1065
    assert os.path.exists(os.path.join(output_dir, "dataset_info.json"))
1066
    assert os.path.exists(os.path.join(output_dir, f"{builder.dataset_name}-train.arrow"))
1067
    assert not os.path.exists(os.path.join(output_dir + ".incomplete"))
1068

1069

1070
def test_builder_download_and_prepare_with_relative_output_dir():
1071
    with set_current_working_directory_to_temp_dir():
1072
        builder = DummyGeneratorBasedBuilder()
1073
        output_dir = "test-out"
1074
        builder.download_and_prepare(output_dir)
1075
        assert Path(builder._output_dir).resolve().as_posix().startswith(Path(output_dir).resolve().as_posix())
1076
        assert os.path.exists(os.path.join(output_dir, "dataset_info.json"))
1077
        assert os.path.exists(os.path.join(output_dir, f"{builder.dataset_name}-train.arrow"))
1078
        assert not os.path.exists(os.path.join(output_dir + ".incomplete"))
1079

1080

1081
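# mockfs is a pytest fixture (defined elsewhere in the test suite) exposing an fsspec "mock://" filesystem,
# so download_and_prepare output can be written to and read back from a remote-like store.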
def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs):
1082
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
1083
    builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
1084
    assert builder._output_dir.startswith("mock://my_dataset")
1085
    assert is_local_path(builder._cache_downloaded_dir)
1086
    assert isinstance(builder._fs, type(mockfs))
1087
    assert builder._fs.storage_options == mockfs.storage_options
1088
    assert mockfs.exists("my_dataset/dataset_info.json")
1089
    assert mockfs.exists(f"my_dataset/{builder.dataset_name}-train.arrow")
1090
    assert not mockfs.exists("my_dataset.incomplete")
1091

1092

1093
def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, caplog):
1094
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
1095
    mockfs.makedirs("my_dataset")
1096
    DatasetInfo().write_to_directory("mock://my_dataset", storage_options=mockfs.storage_options)
1097
    mockfs.touch(f"my_dataset/{builder.dataset_name}-train.arrow")
1098
    caplog.clear()
1099
    with caplog.at_level(INFO, logger=get_logger().name):
1100
        builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
1101
    assert "Found cached dataset" in caplog.text
1102

1103

1104
def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
1105
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
1106
    builder.download_and_prepare(file_format="parquet")
1107
    assert builder.info.splits["train"].num_examples == 100
1108
    parquet_path = os.path.join(
1109
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
1110
    )
1111
    assert os.path.exists(parquet_path)
1112
    assert pq.ParquetFile(parquet_path) is not None
1113

1114

1115
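# Patching datasets.config.MAX_SHARD_SIZE down to 1 byte forces a new shard per writer batch:
# 100 examples / writer_batch_size of 25 = 4 parquet shards.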
def test_generator_based_builder_download_and_prepare_sharded(tmp_path):
1116
    writer_batch_size = 25
1117
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size)
1118
    with patch("datasets.config.MAX_SHARD_SIZE", 1):  # one batch per shard
1119
        builder.download_and_prepare(file_format="parquet")
1120
    expected_num_shards = 100 // writer_batch_size
1121
    assert builder.info.splits["train"].num_examples == 100
1122
    parquet_path = os.path.join(
1123
        tmp_path,
1124
        builder.dataset_name,
1125
        "default",
1126
        "0.0.0",
1127
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
1128
    )
1129
    assert os.path.exists(parquet_path)
1130
    parquet_files = [
1131
        pq.ParquetFile(parquet_path)
1132
        for parquet_path in Path(tmp_path).rglob(
1133
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
1134
        )
1135
    ]
1136
    assert len(parquet_files) == expected_num_shards
1137
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
1138

1139

1140
def test_generator_based_builder_download_and_prepare_with_max_shard_size(tmp_path):
    writer_batch_size = 25
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)  # one batch per shard
    expected_num_shards = 100 // writer_batch_size
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100


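# With num_proc=2 the split is prepared by two worker processes; the assertions below expect
# the builder's input shards (4 files of 100 examples each) to end up in two arrow shards of
# 200 rows each, 400 examples in total.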
def test_generator_based_builder_download_and_prepare_with_num_proc(tmp_path):
    builder = DummyGeneratorBasedBuilderWithShards(cache_dir=tmp_path)
    builder.download_and_prepare(num_proc=2)
    expected_num_shards = 2
    assert builder.info.splits["train"].num_examples == 400
    assert builder.info.splits["train"].shard_lengths == [200, 200]
    arrow_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.arrow",
    )
    assert os.path.exists(arrow_path)
    ds = builder.as_dataset("train")
    assert len(ds) == 400
    assert ds.to_dict() == {
        "id": [i for _ in range(4) for i in range(100)],
        "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)],
    }


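# A builder that cannot split its work into a known number of shards should still prepare fine
# sequentially, but requesting num_proc=2 is expected to raise a RuntimeError.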
@pytest.mark.parametrize(
    "num_proc, expectation", [(None, does_not_raise()), (1, does_not_raise()), (2, pytest.raises(RuntimeError))]
)
def test_generator_based_builder_download_and_prepare_with_ambiguous_shards(num_proc, expectation, tmp_path):
    builder = DummyGeneratorBasedBuilderWithAmbiguousShards(cache_dir=tmp_path)
    with expectation:
        builder.download_and_prepare(num_proc=num_proc)


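# The tests below mirror the generator-based ones for ArrowBasedBuilder, which yields whole
# pyarrow tables per shard instead of individual examples.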
def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet")
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
    )
    assert os.path.exists(parquet_path)
    assert pq.ParquetFile(parquet_path) is not None


def test_arrow_based_builder_download_and_prepare_sharded(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    with patch("datasets.config.MAX_SHARD_SIZE", 1):  # one batch per shard
        builder.download_and_prepare(file_format="parquet")
    expected_num_shards = 10
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100


def test_arrow_based_builder_download_and_prepare_with_max_shard_size(tmp_path):
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)  # one table per shard
    expected_num_shards = 10
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.parquet",
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(
            f"{builder.dataset_name}-train-*-of-{expected_num_shards:05d}.parquet"
        )
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100


def test_arrow_based_builder_download_and_prepare_with_num_proc(tmp_path):
    builder = DummyArrowBasedBuilderWithShards(cache_dir=tmp_path)
    builder.download_and_prepare(num_proc=2)
    expected_num_shards = 2
    assert builder.info.splits["train"].num_examples == 400
    assert builder.info.splits["train"].shard_lengths == [200, 200]
    arrow_path = os.path.join(
        tmp_path,
        builder.dataset_name,
        "default",
        "0.0.0",
        f"{builder.dataset_name}-train-00000-of-{expected_num_shards:05d}.arrow",
    )
    assert os.path.exists(arrow_path)
    ds = builder.as_dataset("train")
    assert len(ds) == 400
    assert ds.to_dict() == {
        "id": [i for _ in range(4) for i in range(100)],
        "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)],
    }


@pytest.mark.parametrize(
    "num_proc, expectation", [(None, does_not_raise()), (1, does_not_raise()), (2, pytest.raises(RuntimeError))]
)
def test_arrow_based_builder_download_and_prepare_with_ambiguous_shards(num_proc, expectation, tmp_path):
    builder = DummyArrowBasedBuilderWithAmbiguousShards(cache_dir=tmp_path)
    with expectation:
        builder.download_and_prepare(num_proc=num_proc)


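# Beam-based preparation runs an Apache Beam pipeline (here with the local DirectRunner);
# writing as parquet is still expected to produce a single train parquet file in the cache dir.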
@require_beam
def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
    builder.download_and_prepare(file_format="parquet")
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.dataset_name, "default", "0.0.0", f"{builder.dataset_name}-train.parquet"
    )
    assert os.path.exists(parquet_path)
    assert pq.ParquetFile(parquet_path) is not None