import os

import pytest
import yaml

from datasets.features.features import Features, Value
from datasets.info import DatasetInfo, DatasetInfosDict


@pytest.mark.parametrize(
    "files",
    [
        ["full:README.md", "dataset_infos.json"],
        ["empty:README.md", "dataset_infos.json"],
        ["dataset_infos.json"],
        ["full:README.md"],
    ],
)
def test_from_dir(files, tmp_path_factory):
    """Check that ``DatasetInfosDict.from_directory`` reads dataset info from a
    README.md YAML front matter block and/or the legacy dataset_infos.json file.

    Each parametrized case lists the files to create in the directory; the
    ``full:``/``empty:`` prefixes select the README.md content.
    """
    dataset_infos_dir = tmp_path_factory.mktemp("dset_infos_dir")
    if "full:README.md" in files:
        # README.md whose YAML front matter carries a dataset_info section.
        # Explicit encoding avoids platform-dependent defaults (PEP 597).
        with open(dataset_infos_dir / "README.md", "w", encoding="utf-8") as f:
            f.write("---\ndataset_info:\n  dataset_size: 42\n---")
    if "empty:README.md" in files:
        with open(dataset_infos_dir / "README.md", "w", encoding="utf-8") as f:
            f.write("")
    # we want to support dataset_infos.json for backward compatibility
    if "dataset_infos.json" in files:
        with open(dataset_infos_dir / "dataset_infos.json", "w", encoding="utf-8") as f:
            f.write('{"default": {"dataset_size": 42}}')
    dataset_infos = DatasetInfosDict.from_directory(dataset_infos_dir)
    assert dataset_infos
    assert dataset_infos["default"].dataset_size == 42


@pytest.mark.parametrize(
    "dataset_info",
    [
        DatasetInfo(),
        DatasetInfo(
            description="foo",
            features=Features({"a": Value("int32")}),
            builder_name="builder",
            config_name="config",
            version="1.0.0",
            splits=[{"name": "train"}],
            download_size=42,
        ),
    ],
)
def test_dataset_info_dump_and_reload(tmp_path, dataset_info: DatasetInfo):
    """Round-trip a DatasetInfo through write_to_directory/from_directory."""
    target_dir = str(tmp_path)
    dataset_info.write_to_directory(target_dir)
    # the dump must land on disk as dataset_info.json ...
    assert os.path.exists(os.path.join(target_dir, "dataset_info.json"))
    # ... and reloading it must reproduce an equal object
    restored = DatasetInfo.from_directory(target_dir)
    assert restored == dataset_info


def test_dataset_info_to_yaml_dict():
    """Check that ``_to_yaml_dict`` exports exactly the YAML-allowed fields,
    that every exported value has a plain YAML-serializable type, and that the
    result survives a yaml safe_dump/safe_load round trip."""
    dataset_info = DatasetInfo(
        description="foo",
        citation="bar",
        homepage="https://foo.bar",
        license="CC0",
        features=Features({"a": Value("int32")}),
        post_processed={},
        supervised_keys=(),
        task_templates=[],
        builder_name="builder",
        config_name="config",
        version="1.0.0",
        splits=[{"name": "train", "num_examples": 42}],
        download_checksums={},
        download_size=1337,
        post_processing_size=442,
        dataset_size=1234,
        size_in_bytes=1337 + 442 + 1234,
    )
    yaml_dict = dataset_info._to_yaml_dict()
    # the exported keys are exactly the ones whitelisted for YAML
    assert sorted(yaml_dict) == sorted(DatasetInfo._INCLUDED_INFO_IN_YAML)
    for field_name in DatasetInfo._INCLUDED_INFO_IN_YAML:
        assert field_name in yaml_dict
        # only plain YAML-serializable types may appear in the export
        assert isinstance(yaml_dict[field_name], (list, dict, int, str))
    serialized = yaml.safe_dump(yaml_dict)
    assert yaml.safe_load(serialized) == yaml_dict


def test_dataset_info_to_yaml_dict_empty():
    """A default-constructed DatasetInfo must export an empty YAML dict."""
    assert DatasetInfo()._to_yaml_dict() == {}


@pytest.mark.parametrize(
    "dataset_infos_dict",
    [
        DatasetInfosDict(),
        DatasetInfosDict({"default": DatasetInfo()}),
        DatasetInfosDict({"my_config_name": DatasetInfo()}),
        DatasetInfosDict(
            {
                "default": DatasetInfo(
                    description="foo",
                    features=Features({"a": Value("int32")}),
                    builder_name="builder",
                    config_name="config",
                    version="1.0.0",
                    splits=[{"name": "train"}],
                    download_size=42,
                )
            }
        ),
        DatasetInfosDict(
            {
                "v1": DatasetInfo(dataset_size=42),
                "v2": DatasetInfo(dataset_size=1337),
            }
        ),
    ],
)
def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: DatasetInfosDict):
    """Round-trip a DatasetInfosDict through write_to_directory/from_directory."""
    target_dir = str(tmp_path)
    dataset_infos_dict.write_to_directory(target_dir)
    reloaded = DatasetInfosDict.from_directory(target_dir)

    # the config_name of the dataset_infos_dict take over the attribute
    for name in dataset_infos_dict:
        info = dataset_infos_dict[name]
        info.config_name = name
        # the yaml representation doesn't include fields like description or citation
        # so we just test that we can recover what we can from the yaml
        dataset_infos_dict[name] = DatasetInfo._from_yaml_dict(info._to_yaml_dict())
    assert dataset_infos_dict == reloaded

    if dataset_infos_dict:
        # a non-empty dict is persisted as README.md YAML front matter
        assert os.path.exists(os.path.join(target_dir, "README.md"))


@pytest.mark.parametrize(
    "dataset_info",
    [
        None,
        DatasetInfo(),
        DatasetInfo(
            description="foo",
            features=Features({"a": Value("int32")}),
            builder_name="builder",
            config_name="config",
            version="1.0.0",
            splits=[{"name": "train"}],
            download_size=42,
            dataset_name="dataset_name",
        ),
    ],
)
def test_from_merge_same_dataset_infos(dataset_info):
    """Merging several copies of one DatasetInfo must yield an equal object;
    merging only Nones must yield a default DatasetInfo."""
    num_elements = 3
    dataset_info_list = (
        [dataset_info.copy() for _ in range(num_elements)]
        if dataset_info is not None
        else [None] * num_elements
    )
    merged = DatasetInfo.from_merge(dataset_info_list)
    expected = dataset_info if dataset_info is not None else DatasetInfo()
    assert expected == merged