datasets / test_hf_gcp.py
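# Tests that datasets mirrored on the Hugging Face GCP bucket are usable:
# their dataset_info.json is reachable, and the prepared data can be either
# downloaded from HF GCS or streamed without a local build.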
import os
from tempfile import TemporaryDirectory
from unittest import TestCase

import pytest
from absl.testing import parameterized

from datasets import config
from datasets.arrow_reader import HF_GCP_BASE_URL
from datasets.builder import DatasetBuilder
from datasets.dataset_dict import IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils.file_utils import cached_path

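# (dataset, config_name) pairs whose preprocessed files are expected to be
# available on the Hugging Face GCP bucket (served under HF_GCP_BASE_URL).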
DATASETS_ON_HF_GCP = [
    {"dataset": "wikipedia", "config_name": "20220301.de"},
    {"dataset": "wikipedia", "config_name": "20220301.en"},
    {"dataset": "wikipedia", "config_name": "20220301.fr"},
    {"dataset": "wikipedia", "config_name": "20220301.frr"},
    {"dataset": "wikipedia", "config_name": "20220301.it"},
    {"dataset": "wikipedia", "config_name": "20220301.simple"},
    {"dataset": "wiki40b", "config_name": "en"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.compressed"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.no_index"},
    {"dataset": "wiki_dpr", "config_name": "psgs_w100.multiset.no_index"},
    {"dataset": "natural_questions", "config_name": "default"},
]


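# Build `parameterized.named_parameters` entries: one per (dataset, config_name)
# pair when `with_config=True`, otherwise one per distinct dataset name.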
def list_datasets_on_hf_gcp_parameters(with_config=True):
    if with_config:
        return [
            {
                "testcase_name": d["dataset"] + "/" + d["config_name"],
                "dataset": d["dataset"],
                "config_name": d["config_name"],
            }
            for d in DATASETS_ON_HF_GCP
        ]
    else:
        return [
            {"testcase_name": dataset, "dataset": dataset} for dataset in {d["dataset"] for d in DATASETS_ON_HF_GCP}
        ]


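# One test case per mirrored (dataset, config): the dataset_info.json for that
# configuration should be downloadable from the GCP bucket.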
@parameterized.named_parameters(list_datasets_on_hf_gcp_parameters(with_config=True))
class TestDatasetOnHfGcp(TestCase):
    dataset = None
    config_name = None

    def test_dataset_info_available(self, dataset, config_name):
        with TemporaryDirectory() as tmp_dir:
            dataset_module = dataset_module_factory(dataset, cache_dir=tmp_dir)

            builder_cls = import_main_class(dataset_module.module_path, dataset=True)

            builder_instance: DatasetBuilder = builder_cls(
                cache_dir=tmp_dir,
                config_name=config_name,
                hash=dataset_module.hash,
            )

            # The info file is expected at <HF_GCP_BASE_URL>/<relative data dir>/<DATASET_INFO_FILENAME>
            dataset_info_url = "/".join(
                [
                    HF_GCP_BASE_URL,
                    builder_instance._relative_data_dir(with_hash=False).replace(os.sep, "/"),
                    config.DATASET_INFO_FILENAME,
                ]
            )
            dataset_info_path = cached_path(dataset_info_url, cache_dir=tmp_dir)
            self.assertTrue(os.path.exists(dataset_info_path))


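# Integration test: load wikipedia 20220301.frr by downloading the already
# prepared Arrow files from HF GCS instead of building the dataset locally.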
@pytest.mark.integration
def test_as_dataset_from_hf_gcs(tmp_path_factory):
    tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_dir)
    builder_cls = import_main_class(dataset_module.module_path)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_dir,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )
    # use the HF cloud storage, not the original download_and_prepare that uses apache-beam
    builder_instance._download_and_prepare = None
    builder_instance.download_and_prepare()
    ds = builder_instance.as_dataset()
    assert ds


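# Integration test: the same config can be streamed directly from HF GCS,
# yielding an IterableDatasetDict without downloading the full dataset.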
@pytest.mark.integration
def test_as_streaming_dataset_from_hf_gcs(tmp_path):
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_path)
    builder_cls = import_main_class(dataset_module.module_path, dataset=True)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_path,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )
    ds = builder_instance.as_streaming_dataset()
    assert ds
    assert isinstance(ds, IterableDatasetDict)
    assert "train" in ds
    assert isinstance(ds["train"], IterableDataset)
    assert next(iter(ds["train"]))