import os
from tempfile import TemporaryDirectory
from unittest import TestCase

import pytest
from absl.testing import parameterized

from datasets import config
from datasets.arrow_reader import HF_GCP_BASE_URL
from datasets.builder import DatasetBuilder
from datasets.dataset_dict import IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils.file_utils import cached_path
# (dataset, config_name) pairs that have prepared Arrow files mirrored on the
# Hugging Face GCS bucket; each entry becomes one parameterized test case.
DATASETS_ON_HF_GCP = [
    {"dataset": "wikipedia", "config_name": "20220301.de"},
    {"dataset": "wikipedia", "config_name": "20220301.en"},
    {"dataset": "wikipedia", "config_name": "20220301.fr"},
    {"dataset": "wikipedia", "config_name": "20220301.frr"},
    {"dataset": "wikipedia", "config_name": "20220301.it"},
    {"dataset": "wikipedia", "config_name": "20220301.simple"},
    {"dataset": "wiki40b", "config_name": "en"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.compressed"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.no_index"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.multiset.no_index"},
    {"dataset": "natural_questions", "config_name": "default"},
]
def list_datasets_on_hf_gcp_parameters(with_config=True, datasets_on_hf_gcp=None):
    """Build kwargs dicts for ``absl.testing.parameterized.named_parameters``.

    Args:
        with_config: if True, emit one test case per (dataset, config) pair;
            if False, emit one test case per unique dataset name.
        datasets_on_hf_gcp: optional override of the module-level
            ``DATASETS_ON_HF_GCP`` list; each item is a dict with "dataset"
            and "config_name" keys. Defaults to ``DATASETS_ON_HF_GCP``.

    Returns:
        A list of dicts, each with a unique "testcase_name" key plus the
        parameters injected into the generated test method.
    """
    if datasets_on_hf_gcp is None:
        datasets_on_hf_gcp = DATASETS_ON_HF_GCP
    if with_config:
        return [
            {
                "testcase_name": d["dataset"] + "/" + d["config_name"],
                "dataset": d["dataset"],
                "config_name": d["config_name"],
            }
            for d in datasets_on_hf_gcp
        ]
    else:
        # Deduplicate dataset names with a set comprehension before emitting.
        return [
            {"testcase_name": dataset, "dataset": dataset}
            for dataset in {d["dataset"] for d in datasets_on_hf_gcp}
        ]
@parameterized.named_parameters(list_datasets_on_hf_gcp_parameters(with_config=True))
class TestDatasetOnHfGcp(TestCase):
    """Check that dataset_info.json is reachable on the HF GCS mirror.

    The decorator generates one test method per (dataset, config_name) pair
    from DATASETS_ON_HF_GCP, passing the pair as method arguments.
    """

    # Placeholders; the parameterized decorator supplies the real values
    # per generated test case.
    dataset = None
    config_name = None

    def test_dataset_info_available(self, dataset, config_name):
        """The info file for (dataset, config_name) can be downloaded from HF GCS."""
        with TemporaryDirectory() as tmp_dir:
            # Resolve the dataset module for this dataset name.
            dataset_module = dataset_module_factory(dataset, cache_dir=tmp_dir)

            builder_cls = import_main_class(dataset_module.module_path, dataset=True)

            builder_instance: DatasetBuilder = builder_cls(
                cache_dir=tmp_dir,
                config_name=config_name,
                hash=dataset_module.hash,
            )

            # Bucket layout: <HF_GCP_BASE_URL>/<relative data dir>/<info file>.
            # _relative_data_dir uses the local os.sep, so normalize to "/".
            dataset_info_url = "/".join(
                [
                    HF_GCP_BASE_URL,
                    builder_instance._relative_data_dir(with_hash=False).replace(os.sep, "/"),
                    config.DATASET_INFO_FILENAME,
                ]
            )
            # Fixed typo: was `datset_info_path`.
            dataset_info_path = cached_path(dataset_info_url, cache_dir=tmp_dir)
            self.assertTrue(os.path.exists(dataset_info_path))
@pytest.mark.integration
def test_as_dataset_from_hf_gcs(tmp_path_factory):
    """Download-and-prepare wikipedia 20220301.frr from the HF GCS mirror and load it."""
    tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_dir)
    builder_cls = import_main_class(dataset_module.module_path)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_dir,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )

    # Break the local preparation path on purpose so download_and_prepare
    # must fall back to fetching the already-prepared files from HF GCS.
    builder_instance._download_and_prepare = None

    builder_instance.download_and_prepare()
    ds = builder_instance.as_dataset()
    assert ds
@pytest.mark.integration
def test_as_streaming_dataset_from_hf_gcs(tmp_path):
    """Stream the wikipedia 20220301.frr config and pull one example."""
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_path)
    builder_cls = import_main_class(dataset_module.module_path, dataset=True)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_path,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )

    ds = builder_instance.as_streaming_dataset()
    assert ds
    assert isinstance(ds, IterableDatasetDict)
    # Guard before indexing so a missing split fails with a clear assertion.
    assert "train" in ds
    assert isinstance(ds["train"], IterableDataset)
    # The stream must yield at least one (truthy) example.
    assert next(iter(ds["train"]))