import importlib
import os
import pickle
import shutil
import tempfile
import time
from hashlib import sha256
from multiprocessing import Pool
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import dill
import pyarrow as pa
import pytest
import requests

import datasets
from datasets import config, load_dataset, load_from_disk
from datasets.arrow_dataset import Dataset
from datasets.arrow_writer import ArrowWriter
from datasets.builder import DatasetBuilder
from datasets.config import METADATA_CONFIGS_FIELD
from datasets.data_files import DataFilesDict, DataFilesPatternsDict
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.download.download_config import DownloadConfig
from datasets.exceptions import DatasetNotFoundError
from datasets.features import Features, Image, Value
from datasets.iterable_dataset import IterableDataset
from datasets.load import (
    CachedDatasetModuleFactory,
    CachedMetricModuleFactory,
    GithubMetricModuleFactory,
    HubDatasetModuleFactoryWithoutScript,
    HubDatasetModuleFactoryWithParquetExport,
    HubDatasetModuleFactoryWithScript,
    LocalDatasetModuleFactoryWithoutScript,
    LocalDatasetModuleFactoryWithScript,
    LocalMetricModuleFactory,
    PackagedDatasetModuleFactory,
    infer_module_for_data_files_list,
    infer_module_for_data_files_list_in_archives,
    load_dataset_builder,
    resolve_trust_remote_code,
)
from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig
from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig
from datasets.packaged_modules.parquet.parquet import ParquetConfig
from datasets.utils import _datasets_server
from datasets.utils.logging import INFO, get_logger

from .utils import (
    OfflineSimulationMode,
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
    offline,
    require_pil,
    require_sndfile,
    set_current_working_directory_to_temp_dir,
)


DATASET_LOADING_SCRIPT_NAME = "__dummy_dataset1__"

DATASET_LOADING_SCRIPT_CODE = """
import os

import datasets
from datasets import DatasetInfo, Features, Split, SplitGenerator, Value


class __DummyDataset1__(datasets.GeneratorBasedBuilder):

    def _info(self) -> DatasetInfo:
        return DatasetInfo(features=Features({"text": Value("string")}))

    def _split_generators(self, dl_manager):
        return [
            SplitGenerator(Split.TRAIN, gen_kwargs={"filepath": os.path.join(dl_manager.manual_dir, "train.txt")}),
            SplitGenerator(Split.TEST, gen_kwargs={"filepath": os.path.join(dl_manager.manual_dir, "test.txt")}),
        ]

    def _generate_examples(self, filepath, **kwargs):
        with open(filepath, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                yield i, {"text": line.strip()}
"""

SAMPLE_DATASET_IDENTIFIER = "hf-internal-testing/dataset_with_script"  # has dataset script and also a parquet export
SAMPLE_DATASET_IDENTIFIER2 = "hf-internal-testing/dataset_with_data_files"  # only has data files
SAMPLE_DATASET_IDENTIFIER3 = "hf-internal-testing/multi_dir_dataset"  # has multiple data directories
SAMPLE_DATASET_IDENTIFIER4 = "hf-internal-testing/imagefolder_with_metadata"  # imagefolder with a metadata file outside of the train/test directories
SAMPLE_DATASET_IDENTIFIER5 = "hf-internal-testing/imagefolder_with_metadata_no_splits"  # imagefolder with a metadata file and no default split names in data files
SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER = "hf-internal-testing/_dummy"
SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST = "_dummy"
SAMPLE_DATASET_NO_CONFIGS_IN_METADATA = "hf-internal-testing/audiofolder_no_configs_in_metadata"
SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT = (
    "hf-internal-testing/audiofolder_two_configs_in_metadata_with_default"
)


METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__"

METRIC_LOADING_SCRIPT_CODE = """
import datasets
from datasets import MetricInfo, Features, Value


class __DummyMetric1__(datasets.Metric):

    def _info(self):
        return MetricInfo(features=Features({"predictions": Value("int64"), "references": Value("int64")}))

    def _compute(self, predictions, references):
        return {"__dummy_metric1__": sum(int(p == r) for p, r in zip(predictions, references))}
"""


@pytest.fixture
def data_dir(tmp_path):
    data_dir = tmp_path / "data_dir"
    data_dir.mkdir()
    with open(data_dir / "train.txt", "w") as f:
        f.write("foo\n" * 10)
    with open(data_dir / "test.txt", "w") as f:
        f.write("bar\n" * 10)
    return str(data_dir)


@pytest.fixture
def data_dir_with_arrow(tmp_path):
    data_dir = tmp_path / "data_dir"
    data_dir.mkdir()
    output_train = os.path.join(data_dir, "train.arrow")
    with ArrowWriter(path=output_train) as writer:
        writer.write_table(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
        num_examples, num_bytes = writer.finalize()
    assert num_examples == 10
    assert num_bytes > 0
    output_test = os.path.join(data_dir, "test.arrow")
    with ArrowWriter(path=output_test) as writer:
        writer.write_table(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
        num_examples, num_bytes = writer.finalize()
    assert num_examples == 10
    assert num_bytes > 0
    return str(data_dir)


@pytest.fixture
def data_dir_with_metadata(tmp_path):
    data_dir = tmp_path / "data_dir_with_metadata"
    data_dir.mkdir()
    with open(data_dir / "train.jpg", "wb") as f:
        f.write(b"train_image_bytes")
    with open(data_dir / "test.jpg", "wb") as f:
        f.write(b"test_image_bytes")
    with open(data_dir / "metadata.jsonl", "w") as f:
        f.write(
            """\
        {"file_name": "train.jpg", "caption": "Cool train image"}
        {"file_name": "test.jpg", "caption": "Cool test image"}
        """
        )
    return str(data_dir)


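# The fixtures below write a dataset card (README.md) whose YAML front matter declares
# configs under METADATA_CONFIGS_FIELD (imported from datasets.config above; at the time
# of writing this field is named "configs").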
@pytest.fixture
def data_dir_with_single_config_in_metadata(tmp_path):
    data_dir = tmp_path / "data_dir_with_one_default_config_in_metadata"

    cats_data_dir = data_dir / "cats"
    cats_data_dir.mkdir(parents=True)
    dogs_data_dir = data_dir / "dogs"
    dogs_data_dir.mkdir(parents=True)

    with open(cats_data_dir / "cat.jpg", "wb") as f:
        f.write(b"this_is_a_cat_image_bytes")
    with open(dogs_data_dir / "dog.jpg", "wb") as f:
        f.write(b"this_is_a_dog_image_bytes")
    with open(data_dir / "README.md", "w") as f:
        f.write(
            f"""\
---
{METADATA_CONFIGS_FIELD}:
  - config_name: custom
    drop_labels: true
---
        """
        )
    return str(data_dir)


@pytest.fixture
def data_dir_with_config_and_data_files(tmp_path):
    data_dir = tmp_path / "data_dir_with_config_and_data_files"

    cats_data_dir = data_dir / "data" / "cats"
    cats_data_dir.mkdir(parents=True)
    dogs_data_dir = data_dir / "data" / "dogs"
    dogs_data_dir.mkdir(parents=True)

    with open(cats_data_dir / "cat.jpg", "wb") as f:
        f.write(b"this_is_a_cat_image_bytes")
    with open(dogs_data_dir / "dog.jpg", "wb") as f:
        f.write(b"this_is_a_dog_image_bytes")
    with open(data_dir / "README.md", "w") as f:
        f.write(
            f"""\
---
{METADATA_CONFIGS_FIELD}:
  - config_name: custom
    data_files: "data/**/*.jpg"
---
        """
        )
    return str(data_dir)


@pytest.fixture
def data_dir_with_two_config_in_metadata(tmp_path):
    data_dir = tmp_path / "data_dir_with_two_configs_in_metadata"
    cats_data_dir = data_dir / "cats"
    cats_data_dir.mkdir(parents=True)
    dogs_data_dir = data_dir / "dogs"
    dogs_data_dir.mkdir(parents=True)

    with open(cats_data_dir / "cat.jpg", "wb") as f:
        f.write(b"this_is_a_cat_image_bytes")
    with open(dogs_data_dir / "dog.jpg", "wb") as f:
        f.write(b"this_is_a_dog_image_bytes")

    with open(data_dir / "README.md", "w") as f:
        f.write(
            f"""\
---
{METADATA_CONFIGS_FIELD}:
  - config_name: "v1"
    drop_labels: true
    default: true
  - config_name: "v2"
    drop_labels: false
---
        """
        )
    return str(data_dir)


@pytest.fixture
def data_dir_with_data_dir_configs_in_metadata(tmp_path):
    data_dir = tmp_path / "data_dir_with_two_configs_in_metadata"
    cats_data_dir = data_dir / "cats"
    cats_data_dir.mkdir(parents=True)
    dogs_data_dir = data_dir / "dogs"
    dogs_data_dir.mkdir(parents=True)

    with open(cats_data_dir / "cat.jpg", "wb") as f:
        f.write(b"this_is_a_cat_image_bytes")
    with open(dogs_data_dir / "dog.jpg", "wb") as f:
        f.write(b"this_is_a_dog_image_bytes")
    return str(data_dir)


@pytest.fixture
def sub_data_dirs(tmp_path):
    data_dir2 = tmp_path / "data_dir2"
    relative_subdir1 = "subdir1"
    sub_data_dir1 = data_dir2 / relative_subdir1
    sub_data_dir1.mkdir(parents=True)
    with open(sub_data_dir1 / "train.txt", "w") as f:
        f.write("foo\n" * 10)
    with open(sub_data_dir1 / "test.txt", "w") as f:
        f.write("bar\n" * 10)

    relative_subdir2 = "subdir2"
    sub_data_dir2 = data_dir2 / relative_subdir2
    sub_data_dir2.mkdir(parents=True)
    with open(sub_data_dir2 / "train.txt", "w") as f:
        f.write("foo\n" * 10)
    with open(sub_data_dir2 / "test.txt", "w") as f:
        f.write("bar\n" * 10)

    return str(data_dir2), relative_subdir1


@pytest.fixture
def complex_data_dir(tmp_path):
    data_dir = tmp_path / "complex_data_dir"
    data_dir.mkdir()
    (data_dir / "data").mkdir()
    with open(data_dir / "data" / "train.txt", "w") as f:
        f.write("foo\n" * 10)
    with open(data_dir / "data" / "test.txt", "w") as f:
        f.write("bar\n" * 10)
    with open(data_dir / "README.md", "w") as f:
        f.write("This is a readme")
    with open(data_dir / ".dummy", "w") as f:
        f.write("this is a dummy file that is not a data file")
    return str(data_dir)


@pytest.fixture
def dataset_loading_script_dir(tmp_path):
    script_name = DATASET_LOADING_SCRIPT_NAME
    script_dir = tmp_path / script_name
    script_dir.mkdir()
    script_path = script_dir / f"{script_name}.py"
    with open(script_path, "w") as f:
        f.write(DATASET_LOADING_SCRIPT_CODE)
    return str(script_dir)


@pytest.fixture
def dataset_loading_script_dir_readonly(tmp_path):
    script_name = DATASET_LOADING_SCRIPT_NAME
    script_dir = tmp_path / "readonly" / script_name
    script_dir.mkdir(parents=True)
    script_path = script_dir / f"{script_name}.py"
    with open(script_path, "w") as f:
        f.write(DATASET_LOADING_SCRIPT_CODE)
    dataset_loading_script_dir = str(script_dir)
    # Make this directory readonly
    os.chmod(dataset_loading_script_dir, 0o555)
    os.chmod(os.path.join(dataset_loading_script_dir, f"{script_name}.py"), 0o555)
    return dataset_loading_script_dir


@pytest.fixture
def metric_loading_script_dir(tmp_path):
    script_name = METRIC_LOADING_SCRIPT_NAME
    script_dir = tmp_path / script_name
    script_dir.mkdir()
    script_path = script_dir / f"{script_name}.py"
    with open(script_path, "w") as f:
        f.write(METRIC_LOADING_SCRIPT_CODE)
    return str(script_dir)


@pytest.mark.parametrize(
    "data_files, expected_module, expected_builder_kwargs",
    [
        (["train.csv"], "csv", {}),
        (["train.tsv"], "csv", {"sep": "\t"}),
        (["train.json"], "json", {}),
        (["train.jsonl"], "json", {}),
        (["train.parquet"], "parquet", {}),
        (["train.geoparquet"], "parquet", {}),
        (["train.gpq"], "parquet", {}),
        (["train.arrow"], "arrow", {}),
        (["train.txt"], "text", {}),
        (["uppercase.TXT"], "text", {}),
        (["unsupported.ext"], None, {}),
        ([""], None, {}),
    ],
)
def test_infer_module_for_data_files(data_files, expected_module, expected_builder_kwargs):
    module, builder_kwargs = infer_module_for_data_files_list(data_files)
    assert module == expected_module
    assert builder_kwargs == expected_builder_kwargs

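# A quick illustration of the extension-based inference exercised above (a sketch,
# values taken from the parametrized cases):
#     infer_module_for_data_files_list(["train.csv"])        # -> ("csv", {})
#     infer_module_for_data_files_list(["unsupported.ext"])  # -> (None, {})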

@pytest.mark.parametrize(
    "data_file, expected_module",
    [
        ("zip_csv_path", "csv"),
        ("zip_csv_with_dir_path", "csv"),
        ("zip_uppercase_csv_path", "csv"),
        ("zip_unsupported_ext_path", None),
    ],
)
def test_infer_module_for_data_files_in_archives(
    data_file, expected_module, zip_csv_path, zip_csv_with_dir_path, zip_uppercase_csv_path, zip_unsupported_ext_path
):
    data_file_paths = {
        "zip_csv_path": zip_csv_path,
        "zip_csv_with_dir_path": zip_csv_with_dir_path,
        "zip_uppercase_csv_path": zip_uppercase_csv_path,
        "zip_unsupported_ext_path": zip_unsupported_ext_path,
    }
    data_files = [str(data_file_paths[data_file])]
    inferred_module, _ = infer_module_for_data_files_list_in_archives(data_files)
    assert inferred_module == expected_module


class ModuleFactoryTest(TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(
        self,
        jsonl_path,
        data_dir,
        data_dir_with_metadata,
        data_dir_with_single_config_in_metadata,
        data_dir_with_config_and_data_files,
        data_dir_with_two_config_in_metadata,
        sub_data_dirs,
        dataset_loading_script_dir,
        metric_loading_script_dir,
    ):
        self._jsonl_path = jsonl_path
        self._data_dir = data_dir
        self._data_dir_with_metadata = data_dir_with_metadata
        self._data_dir_with_single_config_in_metadata = data_dir_with_single_config_in_metadata
        self._data_dir_with_config_and_data_files = data_dir_with_config_and_data_files
        self._data_dir_with_two_config_in_metadata = data_dir_with_two_config_in_metadata
        self._data_dir2 = sub_data_dirs[0]
        self._sub_data_dir = sub_data_dirs[1]
        self._dataset_loading_script_dir = dataset_loading_script_dir
        self._metric_loading_script_dir = metric_loading_script_dir

    def setUp(self):
        self.hf_modules_cache = tempfile.mkdtemp()
        self.cache_dir = tempfile.mkdtemp()
        self.download_config = DownloadConfig(cache_dir=self.cache_dir)
        self.dynamic_modules_path = datasets.load.init_dynamic_modules(
            name="test_datasets_modules_" + os.path.basename(self.hf_modules_cache),
            hf_modules_cache=self.hf_modules_cache,
        )

    def test_HubDatasetModuleFactoryWithScript_dont_trust_remote_code(self):
        # "lhoestq/test" has a dataset script
        factory = HubDatasetModuleFactoryWithScript(
            "lhoestq/test", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None):  # this will be the default soon
            self.assertRaises(ValueError, factory.get_module)
        factory = HubDatasetModuleFactoryWithScript(
            "lhoestq/test",
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
            trust_remote_code=False,
        )
        self.assertRaises(ValueError, factory.get_module)

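    # With HF_DATASETS_TRUST_REMOTE_CODE unset (None) or trust_remote_code=False, script-based
    # factories refuse to execute remote code, as exercised above.
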
    def test_HubDatasetModuleFactoryWithScript_with_github_dataset(self):
        # "wmt_t2t" has additional imports (internal)
        factory = HubDatasetModuleFactoryWithScript(
            "wmt_t2t", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

    @pytest.mark.filterwarnings("ignore:GithubMetricModuleFactory is deprecated:FutureWarning")
    def test_GithubMetricModuleFactory_with_internal_import(self):
        # "squad_v2" requires additional imports (internal)
        factory = GithubMetricModuleFactory(
            "squad_v2", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    @pytest.mark.filterwarnings("ignore:GithubMetricModuleFactory is deprecated:FutureWarning")
    def test_GithubMetricModuleFactory_with_external_import(self):
        # "bleu" requires additional imports (external from github)
        factory = GithubMetricModuleFactory(
            "bleu", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_LocalMetricModuleFactory(self):
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        factory = LocalMetricModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_LocalDatasetModuleFactoryWithScript(self):
        path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py")
        factory = LocalDatasetModuleFactoryWithScript(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert os.path.isdir(module_factory_result.builder_kwargs["base_path"])

    def test_LocalDatasetModuleFactoryWithScript_dont_trust_remote_code(self):
        path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py")
        factory = LocalDatasetModuleFactoryWithScript(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None):  # this will be the default soon
            self.assertRaises(ValueError, factory.get_module)
        factory = LocalDatasetModuleFactoryWithScript(
            path,
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
            trust_remote_code=False,
        )
        self.assertRaises(ValueError, factory.get_module)

    def test_LocalDatasetModuleFactoryWithoutScript(self):
        factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir)
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert os.path.isdir(module_factory_result.builder_kwargs["base_path"])

    def test_LocalDatasetModuleFactoryWithoutScript_with_data_dir(self):
        factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir2, data_dir=self._sub_data_dir)
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        builder_config = module_factory_result.builder_configs_parameters.builder_configs[0]
        assert (
            builder_config.data_files is not None
            and len(builder_config.data_files["train"]) == 1
            and len(builder_config.data_files["test"]) == 1
        )
        assert all(
            self._sub_data_dir in Path(data_file).parts
            for data_file in builder_config.data_files["train"] + builder_config.data_files["test"]
        )

    def test_LocalDatasetModuleFactoryWithoutScript_with_metadata(self):
        factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir_with_metadata)
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        builder_config = module_factory_result.builder_configs_parameters.builder_configs[0]
        assert (
            builder_config.data_files is not None
            and len(builder_config.data_files["train"]) > 0
            and len(builder_config.data_files["test"]) > 0
        )
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"])
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["test"])

    def test_LocalDatasetModuleFactoryWithoutScript_with_single_config_in_metadata(self):
        factory = LocalDatasetModuleFactoryWithoutScript(
            self._data_dir_with_single_config_in_metadata,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

        module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
        assert module_metadata_configs is not None
        assert len(module_metadata_configs) == 1
        assert next(iter(module_metadata_configs)) == "custom"
        assert "drop_labels" in next(iter(module_metadata_configs.values()))
        assert next(iter(module_metadata_configs.values()))["drop_labels"] is True

        module_builder_configs = module_factory_result.builder_configs_parameters.builder_configs
        assert module_builder_configs is not None
        assert len(module_builder_configs) == 1
        assert isinstance(module_builder_configs[0], ImageFolderConfig)
        assert module_builder_configs[0].name == "custom"
        assert module_builder_configs[0].data_files is not None
        assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict)
        module_builder_configs[0]._resolve_data_files(self._data_dir_with_single_config_in_metadata, DownloadConfig())
        assert isinstance(module_builder_configs[0].data_files, DataFilesDict)
        assert len(module_builder_configs[0].data_files) == 1  # one train split
        assert len(module_builder_configs[0].data_files["train"]) == 2  # two files
        assert module_builder_configs[0].drop_labels is True  # parameter is passed from metadata

        # a lone config in the metadata is automatically treated as the default config
        assert module_factory_result.builder_configs_parameters.default_config_name == "custom"

        # config params are not passed to the builder via builder_kwargs; they are stored directly in builder_configs
        assert "drop_labels" not in module_factory_result.builder_kwargs

    def test_LocalDatasetModuleFactoryWithoutScript_with_config_and_data_files(self):
        factory = LocalDatasetModuleFactoryWithoutScript(
            self._data_dir_with_config_and_data_files,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

        module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
        builder_kwargs = module_factory_result.builder_kwargs
        assert module_metadata_configs is not None
        assert len(module_metadata_configs) == 1
        assert next(iter(module_metadata_configs)) == "custom"
        assert "data_files" in next(iter(module_metadata_configs.values()))
        assert next(iter(module_metadata_configs.values()))["data_files"] == "data/**/*.jpg"
        assert "data_files" not in builder_kwargs

    def test_LocalDatasetModuleFactoryWithoutScript_data_dir_with_config_and_data_files(self):
        factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir_with_config_and_data_files, data_dir="data")
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

        module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
        builder_kwargs = module_factory_result.builder_kwargs
        assert module_metadata_configs is not None
        assert len(module_metadata_configs) == 1
        assert next(iter(module_metadata_configs)) == "custom"
        assert "data_files" in next(iter(module_metadata_configs.values()))
        assert next(iter(module_metadata_configs.values()))["data_files"] == "data/**/*.jpg"
        assert "data_files" in builder_kwargs
        assert "train" in builder_kwargs["data_files"]
        assert len(builder_kwargs["data_files"]["train"]) == 2

    def test_LocalDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self):
        factory = LocalDatasetModuleFactoryWithoutScript(
            self._data_dir_with_two_config_in_metadata,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

        module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
        assert module_metadata_configs is not None
        assert len(module_metadata_configs) == 2
        assert list(module_metadata_configs) == ["v1", "v2"]
        assert "drop_labels" in module_metadata_configs["v1"]
        assert module_metadata_configs["v1"]["drop_labels"] is True
        assert "drop_labels" in module_metadata_configs["v2"]
        assert module_metadata_configs["v2"]["drop_labels"] is False

        module_builder_configs = module_factory_result.builder_configs_parameters.builder_configs
        assert module_builder_configs is not None
        assert len(module_builder_configs) == 2
        module_builder_config_v1, module_builder_config_v2 = module_builder_configs
        assert module_builder_config_v1.name == "v1"
        assert module_builder_config_v2.name == "v2"
        assert isinstance(module_builder_config_v1, ImageFolderConfig)
        assert isinstance(module_builder_config_v2, ImageFolderConfig)
        assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict)
        assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict)
        module_builder_config_v1._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig())
        module_builder_config_v2._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig())
        assert isinstance(module_builder_config_v1.data_files, DataFilesDict)
        assert isinstance(module_builder_config_v2.data_files, DataFilesDict)
        assert sorted(module_builder_config_v1.data_files) == ["train"]
        assert len(module_builder_config_v1.data_files["train"]) == 2
        assert sorted(module_builder_config_v2.data_files) == ["train"]
        assert len(module_builder_config_v2.data_files["train"]) == 2
        assert module_builder_config_v1.drop_labels is True  # parameter is passed from metadata
        assert module_builder_config_v2.drop_labels is False  # parameter is passed from metadata

        assert (
            module_factory_result.builder_configs_parameters.default_config_name == "v1"
        )  # it's marked as the default in the YAML

        # config params are not passed to the builder via builder_kwargs; they are stored directly in builder_configs
        assert "drop_labels" not in module_factory_result.builder_kwargs

    def test_PackagedDatasetModuleFactory(self):
        factory = PackagedDatasetModuleFactory(
            "json", data_files=self._jsonl_path, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_PackagedDatasetModuleFactory_with_data_dir(self):
        factory = PackagedDatasetModuleFactory("json", data_dir=self._data_dir, download_config=self.download_config)
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        data_files = module_factory_result.builder_kwargs.get("data_files")
        assert data_files is not None and len(data_files["train"]) > 0 and len(data_files["test"]) > 0
        assert Path(data_files["train"][0]).parent.samefile(self._data_dir)
        assert Path(data_files["test"][0]).parent.samefile(self._data_dir)

    def test_PackagedDatasetModuleFactory_with_data_dir_and_metadata(self):
        factory = PackagedDatasetModuleFactory(
            "imagefolder", data_dir=self._data_dir_with_metadata, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        data_files = module_factory_result.builder_kwargs.get("data_files")
        assert data_files is not None and len(data_files["train"]) > 0 and len(data_files["test"]) > 0
        assert Path(data_files["train"][0]).parent.samefile(self._data_dir_with_metadata)
        assert Path(data_files["test"][0]).parent.samefile(self._data_dir_with_metadata)
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in data_files["train"])
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in data_files["test"])

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithoutScript(self):
        factory = HubDatasetModuleFactoryWithoutScript(
            SAMPLE_DATASET_IDENTIFIER2, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithoutScript_with_data_dir(self):
        data_dir = "data2"
        factory = HubDatasetModuleFactoryWithoutScript(
            SAMPLE_DATASET_IDENTIFIER3, data_dir=data_dir, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        builder_config = module_factory_result.builder_configs_parameters.builder_configs[0]
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)
        assert (
            builder_config.data_files is not None
            and len(builder_config.data_files["train"]) == 1
            and len(builder_config.data_files["test"]) == 1
        )
        assert all(
            data_dir in Path(data_file).parts
            for data_file in builder_config.data_files["train"] + builder_config.data_files["test"]
        )

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithoutScript_with_metadata(self):
        factory = HubDatasetModuleFactoryWithoutScript(
            SAMPLE_DATASET_IDENTIFIER4, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        builder_config = module_factory_result.builder_configs_parameters.builder_configs[0]
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)
        assert (
            builder_config.data_files is not None
            and len(builder_config.data_files["train"]) > 0
            and len(builder_config.data_files["test"]) > 0
        )
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"])
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["test"])

        factory = HubDatasetModuleFactoryWithoutScript(
            SAMPLE_DATASET_IDENTIFIER5, download_config=self.download_config
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        builder_config = module_factory_result.builder_configs_parameters.builder_configs[0]
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)
        assert (
            builder_config.data_files is not None
            and len(builder_config.data_files) == 1
            and len(builder_config.data_files["train"]) > 0
        )
        assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"])

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadata(self):
        factory = HubDatasetModuleFactoryWithoutScript(
            SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA,
            download_config=self.download_config,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

        module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
        assert module_metadata_configs is not None
        assert len(module_metadata_configs) == 1
        assert next(iter(module_metadata_configs)) == "custom"
        assert "drop_labels" in next(iter(module_metadata_configs.values()))
        assert next(iter(module_metadata_configs.values()))["drop_labels"] is True

        module_builder_configs = module_factory_result.builder_configs_parameters.builder_configs
        assert module_builder_configs is not None
        assert len(module_builder_configs) == 1
        assert isinstance(module_builder_configs[0], AudioFolderConfig)
        assert module_builder_configs[0].name == "custom"
        assert module_builder_configs[0].data_files is not None
        assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict)
        module_builder_configs[0]._resolve_data_files(
            module_factory_result.builder_kwargs["base_path"], DownloadConfig()
        )
        assert isinstance(module_builder_configs[0].data_files, DataFilesDict)
        assert sorted(module_builder_configs[0].data_files) == ["test", "train"]
        assert len(module_builder_configs[0].data_files["train"]) == 3
        assert len(module_builder_configs[0].data_files["test"]) == 3
        assert module_builder_configs[0].drop_labels is True  # parameter is passed from metadata

        # a lone config in the metadata is automatically treated as the default config
        assert module_factory_result.builder_configs_parameters.default_config_name == "custom"

        # config params are not passed to the builder via builder_kwargs; they are stored directly in builder_configs
        assert "drop_labels" not in module_factory_result.builder_kwargs

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self):
        datasets_names = [SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT]
        for dataset_name in datasets_names:
            factory = HubDatasetModuleFactoryWithoutScript(dataset_name, download_config=self.download_config)
            module_factory_result = factory.get_module()
            assert importlib.import_module(module_factory_result.module_path) is not None

            module_metadata_configs = module_factory_result.builder_configs_parameters.metadata_configs
            assert module_metadata_configs is not None
            assert len(module_metadata_configs) == 2
            assert list(module_metadata_configs) == ["v1", "v2"]
            assert "drop_labels" in module_metadata_configs["v1"]
            assert module_metadata_configs["v1"]["drop_labels"] is True
            assert "drop_labels" in module_metadata_configs["v2"]
            assert module_metadata_configs["v2"]["drop_labels"] is False

            module_builder_configs = module_factory_result.builder_configs_parameters.builder_configs
            assert module_builder_configs is not None
            assert len(module_builder_configs) == 2
            module_builder_config_v1, module_builder_config_v2 = module_builder_configs
            assert module_builder_config_v1.name == "v1"
            assert module_builder_config_v2.name == "v2"
            assert isinstance(module_builder_config_v1, AudioFolderConfig)
            assert isinstance(module_builder_config_v2, AudioFolderConfig)
            assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict)
            assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict)
            module_builder_config_v1._resolve_data_files(
                module_factory_result.builder_kwargs["base_path"], DownloadConfig()
            )
            module_builder_config_v2._resolve_data_files(
                module_factory_result.builder_kwargs["base_path"], DownloadConfig()
            )
            assert isinstance(module_builder_config_v1.data_files, DataFilesDict)
            assert isinstance(module_builder_config_v2.data_files, DataFilesDict)
            assert sorted(module_builder_config_v1.data_files) == ["test", "train"]
            assert len(module_builder_config_v1.data_files["train"]) == 3
            assert len(module_builder_config_v1.data_files["test"]) == 3
            assert sorted(module_builder_config_v2.data_files) == ["test", "train"]
            assert len(module_builder_config_v2.data_files["train"]) == 2
            assert len(module_builder_config_v2.data_files["test"]) == 1
            assert module_builder_config_v1.drop_labels is True  # parameter is passed from metadata
            assert module_builder_config_v2.drop_labels is False  # parameter is passed from metadata
            # config params are not passed to the builder via builder_kwargs; they are stored directly in builder_configs
            assert "drop_labels" not in module_factory_result.builder_kwargs

            if dataset_name == SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT:
                assert module_factory_result.builder_configs_parameters.default_config_name == "v1"
            else:
                assert module_factory_result.builder_configs_parameters.default_config_name is None

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithScript(self):
        factory = HubDatasetModuleFactoryWithScript(
            SAMPLE_DATASET_IDENTIFIER,
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None
        assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithParquetExport(self):
        factory = HubDatasetModuleFactoryWithParquetExport(
            SAMPLE_DATASET_IDENTIFIER,
            download_config=self.download_config,
        )
        module_factory_result = factory.get_module()
        assert module_factory_result.module_path == "datasets.packaged_modules.parquet.parquet"
        assert module_factory_result.builder_configs_parameters.builder_configs
        assert isinstance(module_factory_result.builder_configs_parameters.builder_configs[0], ParquetConfig)
        module_factory_result.builder_configs_parameters.builder_configs[0]._resolve_data_files(
            base_path="", download_config=self.download_config
        )
        assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == {
            "train": [
                "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/train/0000.parquet"
            ],
            "validation": [
                "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/validation/0000.parquet"
            ],
        }
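        # Note: these "hf://" URIs are fsspec paths resolved through huggingface_hub's
        # HfFileSystem; the "@<sha>" component pins the Parquet export to an exact revision.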

    @pytest.mark.integration
    def test_HubDatasetModuleFactoryWithParquetExport_errors_on_wrong_sha(self):
        factory = HubDatasetModuleFactoryWithParquetExport(
            SAMPLE_DATASET_IDENTIFIER,
            download_config=self.download_config,
            revision="1a21ac5846fc3f36ad5f128740c58932d3d7806f",
        )
        factory.get_module()
        factory = HubDatasetModuleFactoryWithParquetExport(
            SAMPLE_DATASET_IDENTIFIER,
            download_config=self.download_config,
            revision="wrong_sha",
        )
        with self.assertRaises(_datasets_server.DatasetsServerError):
            factory.get_module()

    @pytest.mark.integration
    def test_CachedDatasetModuleFactory(self):
        name = SAMPLE_DATASET_IDENTIFIER2
        load_dataset_builder(name, cache_dir=self.cache_dir).download_and_prepare()
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                factory = CachedDatasetModuleFactory(
                    name,
                    cache_dir=self.cache_dir,
                )
                module_factory_result = factory.get_module()
                assert importlib.import_module(module_factory_result.module_path) is not None

    def test_CachedDatasetModuleFactory_with_script(self):
        path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py")
        factory = LocalDatasetModuleFactoryWithScript(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                factory = CachedDatasetModuleFactory(
                    DATASET_LOADING_SCRIPT_NAME,
                    dynamic_modules_path=self.dynamic_modules_path,
                )
                module_factory_result = factory.get_module()
                assert importlib.import_module(module_factory_result.module_path) is not None

    @pytest.mark.filterwarnings("ignore:LocalMetricModuleFactory is deprecated:FutureWarning")
    @pytest.mark.filterwarnings("ignore:CachedMetricModuleFactory is deprecated:FutureWarning")
    def test_CachedMetricModuleFactory(self):
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        factory = LocalMetricModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                factory = CachedMetricModuleFactory(
                    METRIC_LOADING_SCRIPT_NAME,
                    dynamic_modules_path=self.dynamic_modules_path,
                )
                module_factory_result = factory.get_module()
                assert importlib.import_module(module_factory_result.module_path) is not None


@pytest.mark.parametrize(
    "factory_class",
    [
        CachedDatasetModuleFactory,
        CachedMetricModuleFactory,
        GithubMetricModuleFactory,
        HubDatasetModuleFactoryWithoutScript,
        HubDatasetModuleFactoryWithScript,
        LocalDatasetModuleFactoryWithoutScript,
        LocalDatasetModuleFactoryWithScript,
        LocalMetricModuleFactory,
        PackagedDatasetModuleFactory,
    ],
)
def test_module_factories(factory_class):
    name = "dummy_name"
    factory = factory_class(name)
    assert factory.name == name


@pytest.mark.integration
class LoadTest(TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog

    def setUp(self):
        self.hf_modules_cache = tempfile.mkdtemp()
        self.cache_dir = tempfile.mkdtemp()
        self.dynamic_modules_path = datasets.load.init_dynamic_modules(
            name="test_datasets_modules2", hf_modules_cache=self.hf_modules_cache
        )

    def tearDown(self):
        shutil.rmtree(self.hf_modules_cache)
        shutil.rmtree(self.cache_dir)

    def _dummy_module_dir(self, modules_dir, dummy_module_name, dummy_code):
        assert dummy_module_name.startswith("__")
        module_dir = os.path.join(modules_dir, dummy_module_name)
        os.makedirs(module_dir, exist_ok=True)
        module_path = os.path.join(module_dir, dummy_module_name + ".py")
        with open(module_path, "w") as f:
            f.write(dummy_code)
        return module_dir

    def test_dataset_module_factory(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # prepare module from directory path
            dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
            module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
            dataset_module = datasets.load.dataset_module_factory(
                module_dir, dynamic_modules_path=self.dynamic_modules_path
            )
            dummy_module = importlib.import_module(dataset_module.module_path)
            self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "hello there")
            self.assertEqual(dataset_module.hash, sha256(dummy_code.encode("utf-8")).hexdigest())
            # prepare module from file path + check resolved_file_path
            dummy_code = "MY_DUMMY_VARIABLE = 'general kenobi'"
            module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
            module_path = os.path.join(module_dir, "__dummy_module_name1__.py")
            dataset_module = datasets.load.dataset_module_factory(
                module_path, dynamic_modules_path=self.dynamic_modules_path
            )
            dummy_module = importlib.import_module(dataset_module.module_path)
            self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "general kenobi")
            self.assertEqual(dataset_module.hash, sha256(dummy_code.encode("utf-8")).hexdigest())
            # missing module
            for offline_simulation_mode in list(OfflineSimulationMode):
                with offline(offline_simulation_mode):
                    with self.assertRaises(
                        (DatasetNotFoundError, ConnectionError, requests.exceptions.ConnectionError)
                    ):
                        datasets.load.dataset_module_factory(
                            "__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path
                        )

    @pytest.mark.integration
    def test_offline_dataset_module_factory(self):
        repo_id = SAMPLE_DATASET_IDENTIFIER2
        builder = load_dataset_builder(repo_id, cache_dir=self.cache_dir)
        builder.download_and_prepare()
        for offline_simulation_mode in list(OfflineSimulationMode):
            with offline(offline_simulation_mode):
                self._caplog.clear()
                # the repo id alone should be enough, without an explicit path to a remote or local file
                dataset_module = datasets.load.dataset_module_factory(repo_id, cache_dir=self.cache_dir)
                self.assertEqual(dataset_module.module_path, "datasets.packaged_modules.cache.cache")
                self.assertIn("Using the latest cached version of the dataset", self._caplog.text)

    def test_offline_dataset_module_factory_with_script(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
            module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
            dataset_module_1 = datasets.load.dataset_module_factory(
                module_dir, dynamic_modules_path=self.dynamic_modules_path
            )
            time.sleep(0.1)  # make sure the python file's modification time actually changes
            dummy_code = "MY_DUMMY_VARIABLE = 'general kenobi'"
            module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
            dataset_module_2 = datasets.load.dataset_module_factory(
                module_dir, dynamic_modules_path=self.dynamic_modules_path
            )
        for offline_simulation_mode in list(OfflineSimulationMode):
            with offline(offline_simulation_mode):
                self._caplog.clear()
                # the module name alone should be enough, without an explicit path to a remote or local file
                dataset_module_3 = datasets.load.dataset_module_factory(
                    "__dummy_module_name2__", dynamic_modules_path=self.dynamic_modules_path
                )
                # it loads the most recent version of the module
                self.assertEqual(dataset_module_2.module_path, dataset_module_3.module_path)
                self.assertNotEqual(dataset_module_1.module_path, dataset_module_3.module_path)
                self.assertIn("Using the latest cached version of the module", self._caplog.text)

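    # Both cached versions of "__dummy_module_name2__" coexist because each script is stored
    # under a subdirectory derived from the sha256 of its code (see test_dataset_module_factory);
    # when offline, the most recently cached version is picked.
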
    def test_load_dataset_from_hub(self):
        with self.assertRaises(DatasetNotFoundError) as context:
            datasets.load_dataset("_dummy")
        self.assertIn(
            "Dataset '_dummy' doesn't exist on the Hub",
            str(context.exception),
        )
        with self.assertRaises(DatasetNotFoundError) as context:
            datasets.load_dataset("_dummy", revision="0.0.0")
        self.assertIn(
            "Dataset '_dummy' doesn't exist on the Hub",
            str(context.exception),
        )
        self.assertIn(
            "at revision '0.0.0'",
            str(context.exception),
        )
        for offline_simulation_mode in list(OfflineSimulationMode):
            with offline(offline_simulation_mode):
                with self.assertRaises(ConnectionError) as context:
                    datasets.load_dataset("_dummy")
                if offline_simulation_mode != OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1:
                    self.assertIn(
                        "Couldn't reach '_dummy' on the Hub",
                        str(context.exception),
                    )

    def test_load_dataset_namespace(self):
        with self.assertRaises(DatasetNotFoundError) as context:
            datasets.load_dataset("hf-internal-testing/_dummy")
        self.assertIn(
            "hf-internal-testing/_dummy",
            str(context.exception),
        )
        for offline_simulation_mode in list(OfflineSimulationMode):
            with offline(offline_simulation_mode):
                with self.assertRaises(ConnectionError) as context:
                    datasets.load_dataset("hf-internal-testing/_dummy")
                self.assertIn("hf-internal-testing/_dummy", str(context.exception), msg=offline_simulation_mode)


@pytest.mark.integration
def test_load_dataset_builder_with_metadata():
    builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4)
    assert isinstance(builder, ImageFolder)
    assert builder.config.name == "default"
    assert builder.config.data_files is not None
    assert builder.config.drop_metadata is None
    with pytest.raises(ValueError):
        datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4, "non-existing-config")
1074

1075

1076
@pytest.mark.integration
1077
def test_load_dataset_builder_config_kwargs_passed_as_arguments():
1078
    builder_default = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4)
1079
    builder_custom = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4, drop_metadata=True)
1080
    assert builder_custom.config.drop_metadata != builder_default.config.drop_metadata
1081
    assert builder_custom.config.drop_metadata is True
1082

1083

1084
@pytest.mark.integration
1085
def test_load_dataset_builder_with_two_configs_in_metadata():
1086
    builder = datasets.load_dataset_builder(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v1")
1087
    assert isinstance(builder, AudioFolder)
1088
    assert builder.config.name == "v1"
1089
    assert builder.config.data_files is not None
1090
    with pytest.raises(ValueError):
1091
        datasets.load_dataset_builder(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA)
1092
    with pytest.raises(ValueError):
1093
        datasets.load_dataset_builder(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "non-existing-config")
1094

1095

1096
@pytest.mark.parametrize("serializer", [pickle, dill])
1097
def test_load_dataset_builder_with_metadata_configs_pickable(serializer):
1098
    builder = datasets.load_dataset_builder(SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA)
1099
    builder_unpickled = serializer.loads(serializer.dumps(builder))
1100
    assert builder.BUILDER_CONFIGS == builder_unpickled.BUILDER_CONFIGS
1101
    assert list(builder_unpickled.builder_configs) == ["custom"]
1102
    assert isinstance(builder_unpickled.builder_configs["custom"], AudioFolderConfig)
1103

1104
    builder2 = datasets.load_dataset_builder(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v1")
1105
    builder2_unpickled = serializer.loads(serializer.dumps(builder2))
1106
    assert builder2.BUILDER_CONFIGS == builder2_unpickled.BUILDER_CONFIGS != builder_unpickled.BUILDER_CONFIGS
1107
    assert list(builder2_unpickled.builder_configs) == ["v1", "v2"]
1108
    assert isinstance(builder2_unpickled.builder_configs["v1"], AudioFolderConfig)
1109
    assert isinstance(builder2_unpickled.builder_configs["v2"], AudioFolderConfig)
1110

1111

1112
def test_load_dataset_builder_for_absolute_script_dir(dataset_loading_script_dir, data_dir):
    builder = datasets.load_dataset_builder(dataset_loading_script_dir, data_dir=data_dir)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == DATASET_LOADING_SCRIPT_NAME
    assert builder.dataset_name == DATASET_LOADING_SCRIPT_NAME
    assert builder.info.features == Features({"text": Value("string")})


def test_load_dataset_builder_for_relative_script_dir(dataset_loading_script_dir, data_dir):
    with set_current_working_directory_to_temp_dir():
        relative_script_dir = DATASET_LOADING_SCRIPT_NAME
        shutil.copytree(dataset_loading_script_dir, relative_script_dir)
        builder = datasets.load_dataset_builder(relative_script_dir, data_dir=data_dir)
        assert isinstance(builder, DatasetBuilder)
        assert builder.name == DATASET_LOADING_SCRIPT_NAME
        assert builder.dataset_name == DATASET_LOADING_SCRIPT_NAME
        assert builder.info.features == Features({"text": Value("string")})


def test_load_dataset_builder_for_script_path(dataset_loading_script_dir, data_dir):
    builder = datasets.load_dataset_builder(
        os.path.join(dataset_loading_script_dir, DATASET_LOADING_SCRIPT_NAME + ".py"), data_dir=data_dir
    )
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == DATASET_LOADING_SCRIPT_NAME
    assert builder.dataset_name == DATASET_LOADING_SCRIPT_NAME
    assert builder.info.features == Features({"text": Value("string")})


def test_load_dataset_builder_for_absolute_data_dir(complex_data_dir):
    builder = datasets.load_dataset_builder(complex_data_dir)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == "text"
    assert builder.dataset_name == Path(complex_data_dir).name
    assert builder.config.name == "default"
    assert isinstance(builder.config.data_files, DataFilesDict)
    assert len(builder.config.data_files["train"]) > 0
    assert len(builder.config.data_files["test"]) > 0


def test_load_dataset_builder_for_relative_data_dir(complex_data_dir):
    with set_current_working_directory_to_temp_dir():
        relative_data_dir = "relative_data_dir"
        shutil.copytree(complex_data_dir, relative_data_dir)
        builder = datasets.load_dataset_builder(relative_data_dir)
        assert isinstance(builder, DatasetBuilder)
        assert builder.name == "text"
        assert builder.dataset_name == relative_data_dir
        assert builder.config.name == "default"
        assert isinstance(builder.config.data_files, DataFilesDict)
        assert len(builder.config.data_files["train"]) > 0
        assert len(builder.config.data_files["test"]) > 0


@pytest.mark.integration
def test_load_dataset_builder_for_community_dataset_with_script():
    builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == "parquet"
    assert builder.dataset_name == SAMPLE_DATASET_IDENTIFIER.split("/")[-1]
    assert builder.config.name == "default"
    assert builder.info.features == Features({"text": Value("string")})
    namespace = SAMPLE_DATASET_IDENTIFIER[: SAMPLE_DATASET_IDENTIFIER.index("/")]
    assert builder._relative_data_dir().startswith(namespace)
    assert builder.__module__.startswith("datasets.")


@pytest.mark.integration
def test_load_dataset_builder_for_community_dataset_with_script_no_parquet_export():
    with patch.object(config, "USE_PARQUET_EXPORT", False):
        builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == SAMPLE_DATASET_IDENTIFIER.split("/")[-1]
    assert builder.dataset_name == SAMPLE_DATASET_IDENTIFIER.split("/")[-1]
    assert builder.config.name == "default"
    assert builder.info.features == Features({"text": Value("string")})
    namespace = SAMPLE_DATASET_IDENTIFIER[: SAMPLE_DATASET_IDENTIFIER.index("/")]
    assert builder._relative_data_dir().startswith(namespace)
    assert SAMPLE_DATASET_IDENTIFIER.replace("/", "--") in builder.__module__


@pytest.mark.integration
def test_load_dataset_builder_use_parquet_export_if_dont_trust_remote_code_keeps_features():
    dataset_name = "food101"
    builder = datasets.load_dataset_builder(dataset_name, trust_remote_code=False)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == "parquet"
    assert builder.dataset_name == dataset_name
    assert builder.config.name == "default"
    assert list(builder.info.features) == ["image", "label"]
    assert builder.info.features["image"] == Image()


@pytest.mark.integration
def test_load_dataset_builder_for_community_dataset_without_script():
    builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2)
    assert isinstance(builder, DatasetBuilder)
    assert builder.name == "text"
    assert builder.dataset_name == SAMPLE_DATASET_IDENTIFIER2.split("/")[-1]
    assert builder.config.name == "default"
    assert isinstance(builder.config.data_files, DataFilesDict)
    assert len(builder.config.data_files["train"]) > 0
    assert len(builder.config.data_files["test"]) > 0


def test_load_dataset_builder_fail():
    with pytest.raises(DatasetNotFoundError):
        datasets.load_dataset_builder("blabla")


@pytest.mark.parametrize("keep_in_memory", [False, True])
1223
def test_load_dataset_local_script(dataset_loading_script_dir, data_dir, keep_in_memory, caplog):
1224
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
1225
        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=keep_in_memory)
1226
    assert isinstance(dataset, DatasetDict)
1227
    assert all(isinstance(d, Dataset) for d in dataset.values())
1228
    assert len(dataset) == 2
1229
    assert isinstance(next(iter(dataset["train"])), dict)
1230

1231

1232
def test_load_dataset_cached_local_script(dataset_loading_script_dir, data_dir, caplog):
1233
    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir)
1234
    assert isinstance(dataset, DatasetDict)
1235
    assert all(isinstance(d, Dataset) for d in dataset.values())
1236
    assert len(dataset) == 2
1237
    assert isinstance(next(iter(dataset["train"])), dict)
1238
    for offline_simulation_mode in list(OfflineSimulationMode):
1239
        with offline(offline_simulation_mode):
1240
            caplog.clear()
1241
            # Load dataset from cache
1242
            dataset = datasets.load_dataset(DATASET_LOADING_SCRIPT_NAME, data_dir=data_dir)
1243
            assert len(dataset) == 2
1244
            assert "Using the latest cached version of the module" in caplog.text
1245
            assert isinstance(next(iter(dataset["train"])), dict)
1246
    with pytest.raises(DatasetNotFoundError) as exc_info:
1247
        datasets.load_dataset(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST)
1248
    assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value)
1249

1250

1251
@pytest.mark.integration
1252
@pytest.mark.parametrize("stream_from_cache, ", [False, True])
1253
def test_load_dataset_cached_from_hub(stream_from_cache, caplog):
1254
    dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER3)
1255
    assert isinstance(dataset, DatasetDict)
1256
    assert all(isinstance(d, Dataset) for d in dataset.values())
1257
    assert len(dataset) == 2
1258
    assert isinstance(next(iter(dataset["train"])), dict)
1259
    for offline_simulation_mode in list(OfflineSimulationMode):
1260
        with offline(offline_simulation_mode):
1261
            caplog.clear()
1262
            # Load dataset from cache
1263
            dataset = datasets.load_dataset(SAMPLE_DATASET_IDENTIFIER3, streaming=stream_from_cache)
1264
            assert len(dataset) == 2
1265
            assert "Using the latest cached version of the dataset" in caplog.text
1266
            assert isinstance(next(iter(dataset["train"])), dict)
1267
    with pytest.raises(DatasetNotFoundError) as exc_info:
1268
        datasets.load_dataset(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST)
1269
    assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value)
1270

1271

1272
def test_load_dataset_streaming(dataset_loading_script_dir, data_dir):
    dataset = load_dataset(dataset_loading_script_dir, streaming=True, data_dir=data_dir)
    assert isinstance(dataset, IterableDatasetDict)
    assert all(isinstance(d, IterableDataset) for d in dataset.values())
    assert len(dataset) == 2
    assert isinstance(next(iter(dataset["train"])), dict)


def test_load_dataset_streaming_gz_json(jsonl_gz_path):
    data_files = jsonl_gz_path
    ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
    assert isinstance(ds, IterableDataset)
    ds_item = next(iter(ds))
    assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


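# Note: files inside zip/tar archives are addressed with fsspec-style chained
# URLs of the form "<protocol>://<pattern-inside-archive>::<url-of-archive>",
# e.g. "zip://*::https://.../sample.zip" to glob all archive members. The test
# below builds such URLs for zip/tar but currently skips them (see the TODO).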
@pytest.mark.integration
@pytest.mark.parametrize(
    "path", ["sample.jsonl", "sample.jsonl.gz", "sample.tar", "sample.jsonl.xz", "sample.zip", "sample.jsonl.zst"]
)
def test_load_dataset_streaming_compressed_files(path):
    repo_id = "hf-internal-testing/compressed_files"
    data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{path}"
    if data_files[-3:] in ("zip", "tar"):  # we need to glob "*" inside archives
        data_files = data_files[-3:] + "://*::" + data_files
        return  # TODO(QL, albert): re-add support for streaming ZIP and TAR archives
    ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
    assert isinstance(ds, IterableDataset)
    ds_item = next(iter(ds))
    assert ds_item == {
        "tokens": ["Ministeri", "de", "Justícia", "d'Espanya"],
        "ner_tags": [1, 2, 2, 2],
        "langs": ["ca", "ca", "ca", "ca"],
        "spans": ["PER: Ministeri de Justícia d'Espanya"],
    }


@pytest.mark.parametrize("path_extension", ["csv", "csv.bz2"])
1310
@pytest.mark.parametrize("streaming", [False, True])
1311
def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv_path):
1312
    paths = {"csv": csv_path, "csv.bz2": bz2_csv_path}
1313
    data_files = str(paths[path_extension])
1314
    features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
1315
    ds = load_dataset("csv", split="train", data_files=data_files, features=features, streaming=streaming)
1316
    assert isinstance(ds, IterableDataset if streaming else Dataset)
1317
    ds_item = next(iter(ds))
1318
    assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1319

1320

1321
@pytest.mark.parametrize("streaming", [False, True])
1322
@pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
1323
def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
1324
    data_file_paths = {
1325
        "zip_csv_path": zip_csv_path,
1326
        "zip_csv_with_dir_path": zip_csv_with_dir_path,
1327
        "csv_path": csv_path,
1328
    }
1329
    data_files = str(data_file_paths[data_file])
1330
    expected_size = 8 if data_file.startswith("zip") else 4
1331
    features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
1332
    ds = load_dataset("csv", split="train", data_files=data_files, features=features, streaming=streaming)
1333
    if streaming:
1334
        ds_item_counter = 0
1335
        for ds_item in ds:
1336
            if ds_item_counter == 0:
1337
                assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1338
            ds_item_counter += 1
1339
        assert ds_item_counter == expected_size
1340
    else:
1341
        assert ds.shape[0] == expected_size
1342
        ds_item = next(iter(ds))
1343
        assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1344

1345

1346
@pytest.mark.parametrize("streaming", [False, True])
1347
@pytest.mark.parametrize("data_file", ["zip_jsonl_path", "zip_jsonl_with_dir_path", "jsonl_path"])
1348
def test_load_dataset_zip_jsonl(data_file, streaming, zip_jsonl_path, zip_jsonl_with_dir_path, jsonl_path):
1349
    data_file_paths = {
1350
        "zip_jsonl_path": zip_jsonl_path,
1351
        "zip_jsonl_with_dir_path": zip_jsonl_with_dir_path,
1352
        "jsonl_path": jsonl_path,
1353
    }
1354
    data_files = str(data_file_paths[data_file])
1355
    expected_size = 8 if data_file.startswith("zip") else 4
1356
    features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
1357
    ds = load_dataset("json", split="train", data_files=data_files, features=features, streaming=streaming)
1358
    if streaming:
1359
        ds_item_counter = 0
1360
        for ds_item in ds:
1361
            if ds_item_counter == 0:
1362
                assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1363
            ds_item_counter += 1
1364
        assert ds_item_counter == expected_size
1365
    else:
1366
        assert ds.shape[0] == expected_size
1367
        ds_item = next(iter(ds))
1368
        assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1369

1370

1371
@pytest.mark.parametrize("streaming", [False, True])
1372
@pytest.mark.parametrize("data_file", ["zip_text_path", "zip_text_with_dir_path", "text_path"])
1373
def test_load_dataset_zip_text(data_file, streaming, zip_text_path, zip_text_with_dir_path, text_path):
1374
    data_file_paths = {
1375
        "zip_text_path": zip_text_path,
1376
        "zip_text_with_dir_path": zip_text_with_dir_path,
1377
        "text_path": text_path,
1378
    }
1379
    data_files = str(data_file_paths[data_file])
1380
    expected_size = 8 if data_file.startswith("zip") else 4
1381
    ds = load_dataset("text", split="train", data_files=data_files, streaming=streaming)
1382
    if streaming:
1383
        ds_item_counter = 0
1384
        for ds_item in ds:
1385
            if ds_item_counter == 0:
1386
                assert ds_item == {"text": "0"}
1387
            ds_item_counter += 1
1388
        assert ds_item_counter == expected_size
1389
    else:
1390
        assert ds.shape[0] == expected_size
1391
        ds_item = next(iter(ds))
1392
        assert ds_item == {"text": "0"}
1393

1394

1395
@pytest.mark.parametrize("streaming", [False, True])
1396
def test_load_dataset_arrow(streaming, data_dir_with_arrow):
1397
    ds = load_dataset("arrow", split="train", data_dir=data_dir_with_arrow, streaming=streaming)
1398
    expected_size = 10
1399
    if streaming:
1400
        ds_item_counter = 0
1401
        for ds_item in ds:
1402
            if ds_item_counter == 0:
1403
                assert ds_item == {"col_1": "foo"}
1404
            ds_item_counter += 1
1405
        assert ds_item_counter == 10
1406
    else:
1407
        assert ds.num_rows == 10
1408
        assert ds.shape[0] == expected_size
1409
        ds_item = next(iter(ds))
1410
        assert ds_item == {"col_1": "foo"}
1411

1412

1413
def test_load_dataset_text_with_unicode_new_lines(text_path_with_unicode_new_lines):
1414
    data_files = str(text_path_with_unicode_new_lines)
1415
    ds = load_dataset("text", split="train", data_files=data_files)
1416
    assert ds.num_rows == 3
1417

1418

1419
def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_extension):
1420
    data_files = str(text_dir_with_unsupported_extension)
1421
    ds = load_dataset("text", split="train", data_files=data_files)
1422
    assert ds.num_rows == 4
1423

1424

1425
@pytest.mark.integration
def test_loading_from_the_datasets_hub():
    with tempfile.TemporaryDirectory() as tmp_dir:
        with load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir) as dataset:
            assert len(dataset["train"]) == 2
            assert len(dataset["validation"]) == 3


@pytest.mark.integration
def test_loading_from_the_datasets_hub_with_token():
    true_request = requests.Session().request

    def assert_auth(method, url, *args, headers, **kwargs):
        assert headers["authorization"] == "Bearer foo"
        return true_request(method, url, *args, headers=headers, **kwargs)

    with patch("requests.Session.request") as mock_request:
        mock_request.side_effect = assert_auth
        with tempfile.TemporaryDirectory() as tmp_dir:
            with offline():
                with pytest.raises((ConnectionError, requests.exceptions.ConnectionError)):
                    load_dataset(SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER, cache_dir=tmp_dir, token="foo")
        mock_request.assert_called()


@pytest.mark.integration
def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, token=hf_token)
    assert next(iter(ds)) is not None


@pytest.mark.integration
def test_load_dataset_builder_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
    builder = load_dataset_builder(hf_private_dataset_repo_txt_data, token=hf_token)
    assert isinstance(builder, DatasetBuilder)


@pytest.mark.integration
def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, token=hf_token)
    assert next(iter(ds)) is not None


@pytest.mark.integration
def test_load_dataset_config_kwargs_passed_as_arguments():
    ds_default = load_dataset(SAMPLE_DATASET_IDENTIFIER4)
    ds_custom = load_dataset(SAMPLE_DATASET_IDENTIFIER4, drop_metadata=True)
    assert list(ds_default["train"].features) == ["image", "caption"]
    assert list(ds_custom["train"].features) == ["image"]


@require_sndfile
@pytest.mark.integration
def test_load_hub_dataset_without_script_with_single_config_in_metadata():
    # load a dataset that has no configurations in its metadata (= with default parameters)
    ds = load_dataset(SAMPLE_DATASET_NO_CONFIGS_IN_METADATA)
    assert list(ds["train"].features) == ["audio", "label"]  # assert label feature is here as expected by default
    assert len(ds["train"]) == 5 and len(ds["test"]) == 4

    ds2 = load_dataset(SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA)  # single config -> no need to specify it
    assert list(ds2["train"].features) == ["audio"]  # assert param `drop_labels=True` from metadata is passed
    assert len(ds2["train"]) == 3 and len(ds2["test"]) == 3

    ds3 = load_dataset(SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA, "custom")
    assert list(ds3["train"].features) == ["audio"]  # assert param `drop_labels=True` from metadata is passed
    assert len(ds3["train"]) == 3 and len(ds3["test"]) == 3

    with pytest.raises(ValueError):
        # no config named "default"
        _ = load_dataset(SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA, "default")


@require_sndfile
@pytest.mark.integration
def test_load_hub_dataset_without_script_with_two_config_in_metadata():
    ds = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v1")
    assert list(ds["train"].features) == ["audio"]  # assert param `drop_labels=True` from metadata is passed
    assert len(ds["train"]) == 3 and len(ds["test"]) == 3

    ds2 = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v2")
    assert list(ds2["train"].features) == [
        "audio",
        "label",
    ]  # assert param `drop_labels=False` from metadata is passed
    assert len(ds2["train"]) == 2 and len(ds2["test"]) == 1

    with pytest.raises(ValueError):
        # config is required but not specified
        _ = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA)

    with pytest.raises(ValueError):
        # no config named "default"
        _ = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "default")

    ds_with_default = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT)
    # it's a dataset with the same data, but its "v1" config is marked as the default one
    assert list(ds_with_default["train"].features) == list(ds["train"].features)
    assert len(ds_with_default["train"]) == len(ds["train"]) and len(ds_with_default["test"]) == len(ds["test"])


@require_sndfile
@pytest.mark.integration
def test_load_hub_dataset_without_script_with_metadata_config_in_parallel():
    # assert it doesn't fail (pickling of the dynamically created class works)
    ds = load_dataset(SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA, num_proc=2)
    assert "label" not in ds["train"].features  # assert param `drop_labels=True` from metadata is passed
    assert len(ds["train"]) == 3 and len(ds["test"]) == 3

    ds = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v1", num_proc=2)
    assert "label" not in ds["train"].features  # assert param `drop_labels=True` from metadata is passed
    assert len(ds["train"]) == 3 and len(ds["test"]) == 3

    ds = load_dataset(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, "v2", num_proc=2)
    assert "label" in ds["train"].features
    assert len(ds["train"]) == 2 and len(ds["test"]) == 1


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("streaming", [True])
def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
    ds = load_dataset(hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, token=hf_token)
    assert isinstance(ds, IterableDataset if streaming else Dataset)
    ds_items = list(ds)
    assert len(ds_items) == 2


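# A dataset's fingerprint is derived from its content and preparation steps,
# not from the absolute cache location, so moving/renaming the cache directory
# should still produce a cache hit with an identical fingerprint.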
def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir, tmp_path, caplog):
    cache_dir1 = tmp_path / "cache1"
    cache_dir2 = tmp_path / "cache2"
    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1)
    fingerprint1 = dataset._fingerprint
    del dataset
    os.rename(cache_dir1, cache_dir2)
    caplog.clear()
    with caplog.at_level(INFO, logger=get_logger().name):
        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir2)
    assert "Found cached dataset" in caplog.text
    assert dataset._fingerprint == fingerprint1, "for the caching mechanism to work, fingerprint should stay the same"
    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="test", cache_dir=cache_dir2)
    assert dataset._fingerprint != fingerprint1


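# The builder cache directory incorporates a hash computed over the resolved
# data files, so editing a data file is expected to yield a different cache
# dir instead of silently reusing stale prepared data.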
def test_load_dataset_builder_then_edit_then_load_again(tmp_path: Path):
    dataset_dir = tmp_path / "test_load_dataset_then_edit_then_load_again"
    dataset_dir.mkdir()
    with open(dataset_dir / "train.txt", "w") as f:
        f.write("Hello there")
    dataset_builder = load_dataset_builder(str(dataset_dir))
    with open(dataset_dir / "train.txt", "w") as f:
        f.write("General Kenobi !")
    edited_dataset_builder = load_dataset_builder(str(dataset_dir))
    assert dataset_builder.cache_dir != edited_dataset_builder.cache_dir


def test_load_dataset_readonly(dataset_loading_script_dir, dataset_loading_script_dir_readonly, data_dir, tmp_path):
    cache_dir1 = tmp_path / "cache1"
    cache_dir2 = tmp_path / "cache2"
    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1)
    fingerprint1 = dataset._fingerprint
    del dataset
    # Load readonly dataset and check that the fingerprint is the same.
    dataset = load_dataset(dataset_loading_script_dir_readonly, data_dir=data_dir, split="train", cache_dir=cache_dir2)
    assert dataset._fingerprint == fingerprint1, "Cannot load a dataset in a readonly folder."


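# datasets.config.IN_MEMORY_MAX_SIZE controls when a dataset is copied into
# memory instead of being memory-mapped: datasets strictly smaller than the
# threshold are loaded in memory, and a value of 0 (the default) disables
# in-memory loading entirely.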
@pytest.mark.parametrize("max_in_memory_dataset_size", ["default", 0, 50, 500])
1592
def test_load_dataset_local_with_default_in_memory(
1593
    max_in_memory_dataset_size, dataset_loading_script_dir, data_dir, monkeypatch
1594
):
1595
    current_dataset_size = 148
1596
    if max_in_memory_dataset_size == "default":
1597
        max_in_memory_dataset_size = 0  # default
1598
    else:
1599
        monkeypatch.setattr(datasets.config, "IN_MEMORY_MAX_SIZE", max_in_memory_dataset_size)
1600
    if max_in_memory_dataset_size:
1601
        expected_in_memory = current_dataset_size < max_in_memory_dataset_size
1602
    else:
1603
        expected_in_memory = False
1604

1605
    with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
1606
        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir)
1607
    assert (dataset["train"].dataset_size < max_in_memory_dataset_size) is expected_in_memory
1608

1609

1610
@pytest.mark.parametrize("max_in_memory_dataset_size", ["default", 0, 100, 1000])
1611
def test_load_from_disk_with_default_in_memory(
1612
    max_in_memory_dataset_size, dataset_loading_script_dir, data_dir, tmp_path, monkeypatch
1613
):
1614
    current_dataset_size = 512  # arrow file size = 512, in-memory dataset size = 148
1615
    if max_in_memory_dataset_size == "default":
1616
        max_in_memory_dataset_size = 0  # default
1617
    else:
1618
        monkeypatch.setattr(datasets.config, "IN_MEMORY_MAX_SIZE", max_in_memory_dataset_size)
1619
    if max_in_memory_dataset_size:
1620
        expected_in_memory = current_dataset_size < max_in_memory_dataset_size
1621
    else:
1622
        expected_in_memory = False
1623

1624
    dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True)
1625
    dataset_path = os.path.join(tmp_path, "saved_dataset")
1626
    dset.save_to_disk(dataset_path)
1627

1628
    with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
1629
        _ = load_from_disk(dataset_path)
1630

1631

1632
@pytest.mark.integration
def test_remote_data_files():
    repo_id = "hf-internal-testing/raw_jsonl"
    filename = "wikiann-bn-validation.jsonl"
    data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"
    ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
    assert isinstance(ds, IterableDataset)
    ds_item = next(iter(ds))
    assert ds_item.keys() == {"langs", "ner_tags", "spans", "tokens"}


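# DownloadConfig(delete_extracted=True) removes files extracted from archives
# (here the .gz) once the dataset has been prepared; by default extracted
# files are kept under <cache_dir>/downloads/extracted, so only .lock files
# remain there when deletion is enabled.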
@pytest.mark.parametrize("deleted", [False, True])
1644
def test_load_dataset_deletes_extracted_files(deleted, jsonl_gz_path, tmp_path):
1645
    data_files = jsonl_gz_path
1646
    cache_dir = tmp_path / "cache"
1647
    if deleted:
1648
        download_config = DownloadConfig(delete_extracted=True, cache_dir=cache_dir / "downloads")
1649
        ds = load_dataset(
1650
            "json", split="train", data_files=data_files, cache_dir=cache_dir, download_config=download_config
1651
        )
1652
    else:  # default
1653
        ds = load_dataset("json", split="train", data_files=data_files, cache_dir=cache_dir)
1654
    assert ds[0] == {"col_1": "0", "col_2": 0, "col_3": 0.0}
1655
    assert (
1656
        [path for path in (cache_dir / "downloads" / "extracted").iterdir() if path.suffix != ".lock"] == []
1657
    ) is deleted
1658

1659

1660
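# The helper below must be a module-level function so multiprocessing.Pool can
# pickle it. All workers share the same cache directory, so they are expected
# to resolve identical cache files rather than each preparing its own copy.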
def distributed_load_dataset(args):
    data_name, tmp_dir, datafiles = args
    dataset = load_dataset(data_name, cache_dir=tmp_dir, data_files=datafiles)
    return dataset


def test_load_dataset_distributed(tmp_path, csv_path):
    num_workers = 5
    args = "csv", str(tmp_path), csv_path
    with Pool(processes=num_workers) as pool:  # start num_workers processes
        datasets = pool.map(distributed_load_dataset, [args] * num_workers)
        assert len(datasets) == num_workers
        assert all(len(dataset) == len(datasets[0]) > 0 for dataset in datasets)
        assert len(datasets[0].cache_files) > 0
        assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)


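# storage_options are forwarded to the fsspec filesystem backing the
# data_files URLs; the "mock://" protocol here comes from the mockfs fixture.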
def test_load_dataset_with_storage_options(mockfs):
    with mockfs.open("data.txt", "w") as f:
        f.write("Hello there\n")
        f.write("General Kenobi !")
    data_files = {"train": ["mock://data.txt"]}
    ds = load_dataset("text", data_files=data_files, storage_options=mockfs.storage_options)
    assert list(ds["train"]) == [{"text": "Hello there"}, {"text": "General Kenobi !"}]


@require_pil
def test_load_dataset_with_storage_options_with_decoding(mockfs, image_file):
    import PIL.Image

    filename = os.path.basename(image_file)
    with mockfs.open(filename, "wb") as fout:
        with open(image_file, "rb") as fin:
            fout.write(fin.read())
    data_files = {"train": ["mock://" + filename]}
    ds = load_dataset("imagefolder", data_files=data_files, storage_options=mockfs.storage_options)
    assert len(ds["train"]) == 1
    assert isinstance(ds["train"][0]["image"], PIL.Image.Image)


def test_load_dataset_without_script_with_zip(zip_csv_path):
    path = str(zip_csv_path.parent)
    ds = load_dataset(path)
    assert list(ds.keys()) == ["train"]
    assert ds["train"].column_names == ["col_1", "col_2", "col_3"]
    assert ds["train"].num_rows == 8
    assert ds["train"][0] == {"col_1": 0, "col_2": 0, "col_3": 0.0}


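# With the current default (config.HF_DATASETS_TRUST_REMOTE_CODE is truthy),
# trust_remote_code=None resolves to True. The second test pins the default to
# None, the planned future behavior, where None raises a ValueError instead.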
@pytest.mark.parametrize("trust_remote_code, expected", [(False, False), (True, True), (None, True)])
1710
def test_resolve_trust_remote_code(trust_remote_code, expected):
1711
    assert resolve_trust_remote_code(trust_remote_code, repo_id="dummy") is expected
1712

1713

1714
@pytest.mark.parametrize("trust_remote_code, expected", [(False, False), (True, True), (None, ValueError)])
1715
def test_resolve_trust_remote_code_future(trust_remote_code, expected):
1716
    with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None):  # this will be the default soon
1717
        if isinstance(expected, bool):
1718
            resolve_trust_remote_code(trust_remote_code, repo_id="dummy") is expected
1719
        else:
1720
            with pytest.raises(expected):
1721
                resolve_trust_remote_code(trust_remote_code, repo_id="dummy")
1722

1723

1724
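# Legacy cache layout (datasets <= 2.15) embedded a config-id hash next to the
# config name ("v2-374bfde4f55442bc"); newer versions use
# <namespace>___<dataset>/<config>/<version>/<builder hash> instead. Reloading
# is expected to pick up the legacy directory when it already exists.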
@pytest.mark.integration
def test_reload_old_cache_from_2_15(tmp_path: Path):
    cache_dir = tmp_path / "test_reload_old_cache_from_2_15"
    builder_cache_dir = (
        cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d"
    )
    builder_cache_dir.mkdir(parents=True)
    arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow"
    dataset_info_path = builder_cache_dir / "dataset_info.json"
    with dataset_info_path.open("w") as f:
        f.write("{}")
    arrow_path.touch()
    builder = load_dataset_builder(
        "polinaeterna/audiofolder_two_configs_in_metadata",
        "v2",
        data_files="v2/train/*",
        cache_dir=cache_dir.as_posix(),
    )
    assert builder.cache_dir == builder_cache_dir.as_posix()  # old cache from 2.15

    builder = load_dataset_builder(
        "polinaeterna/audiofolder_two_configs_in_metadata", "v2", cache_dir=cache_dir.as_posix()
    )
    assert (
        builder.cache_dir
        == (
            cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata" / "v2" / "0.0.0" / str(builder.hash)
        ).as_posix()
    )  # new cache


@pytest.mark.integration
def test_update_dataset_card_data_with_standalone_yaml():
    # Labels defined in .huggingface.yml because they are too long to be in README.md
    from datasets.utils.metadata import MetadataConfigs

    with patch(
        "datasets.utils.metadata.MetadataConfigs.from_dataset_card_data",
        side_effect=MetadataConfigs.from_dataset_card_data,
    ) as card_data_read_mock:
        builder = load_dataset_builder("datasets-maintainers/dataset-with-standalone-yaml")
    assert card_data_read_mock.call_args.args[0]["license"] is not None  # from README.md
    assert card_data_read_mock.call_args.args[0]["dataset_info"] is not None  # from standalone yaml
    assert card_data_read_mock.call_args.args[0]["tags"] == ["test"]  # standalone yaml has precedence
    assert isinstance(
        builder.info.features["label"], datasets.ClassLabel
    )  # correctly loaded from long labels list in standalone yaml