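"""Tests for the JSON reader/writer in `datasets.io.json`.

Fixtures such as `jsonl_path`, `jsonl_312_path`, `shared_datadir`, and
`dataset` are provided by the surrounding test suite (conftest).
"""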
import io
import json

import fsspec
import pytest

from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
from datasets.io.json import JsonDatasetReader, JsonDatasetWriter

from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases


def _check_json_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


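# `keep_in_memory=True` loads the Arrow data into RAM (memory use must grow);
# `False` memory-maps the on-disk cache (memory use must stay flat).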
@pytest.mark.parametrize("keep_in_memory", [False, True])
23
def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
24
    cache_dir = tmp_path / "cache"
25
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
26
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
27
        dataset = JsonDatasetReader(jsonl_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
28
    _check_json_dataset(dataset, expected_features)
29

30

31
@pytest.mark.parametrize(
32
    "features",
33
    [
34
        None,
35
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
36
        {"col_1": "string", "col_2": "string", "col_3": "string"},
37
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
38
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
39
    ],
40
)
41
def test_dataset_from_json_features(features, jsonl_path, tmp_path):
42
    cache_dir = tmp_path / "cache"
43
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
44
    expected_features = features.copy() if features else default_expected_features
45
    features = (
46
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
47
    )
48
    dataset = JsonDatasetReader(jsonl_path, features=features, cache_dir=cache_dir).read()
49
    _check_json_dataset(dataset, expected_features)
50

51

52
@pytest.mark.parametrize(
53
    "features",
54
    [
55
        None,
56
        {"col_3": "float64", "col_1": "string", "col_2": "int64"},
57
    ],
58
)
59
def test_dataset_from_json_with_unsorted_column_names(features, jsonl_312_path, tmp_path):
60
    cache_dir = tmp_path / "cache"
61
    default_expected_features = {"col_3": "float64", "col_1": "string", "col_2": "int64"}
62
    expected_features = features.copy() if features else default_expected_features
63
    features = (
64
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
65
    )
66
    dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
67
    assert isinstance(dataset, Dataset)
68
    assert dataset.num_rows == 2
69
    assert dataset.num_columns == 3
70
    assert dataset.column_names == ["col_3", "col_1", "col_2"]
71
    for feature, expected_dtype in expected_features.items():
72
        assert dataset.features[feature].dtype == expected_dtype
73

74

75
def test_dataset_from_json_with_mismatched_features(jsonl_312_path, tmp_path):
76
    # jsonl_312_path features are {"col_3": "float64", "col_1": "string", "col_2": "int64"}
77
    features = {"col_2": "int64", "col_3": "float64", "col_1": "string"}
78
    expected_features = features.copy()
79
    features = (
80
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
81
    )
82
    cache_dir = tmp_path / "cache"
83
    dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
84
    assert isinstance(dataset, Dataset)
85
    assert dataset.num_rows == 2
86
    assert dataset.num_columns == 3
87
    assert dataset.column_names == ["col_2", "col_3", "col_1"]
88
    for feature, expected_dtype in expected_features.items():
89
        assert dataset.features[feature].dtype == expected_dtype
90

91

92
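# `split` may be None, a NamedSplit, or a plain string; the reader defaults to "train".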
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
93
def test_dataset_from_json_split(split, jsonl_path, tmp_path):
94
    cache_dir = tmp_path / "cache"
95
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
96
    dataset = JsonDatasetReader(jsonl_path, cache_dir=cache_dir, split=split).read()
97
    _check_json_dataset(dataset, expected_features)
98
    assert dataset.split == split if split else "train"
99

100

101
@pytest.mark.parametrize("path_type", [str, list])
102
def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path):
103
    if issubclass(path_type, str):
104
        path = jsonl_path
105
    elif issubclass(path_type, list):
106
        path = [jsonl_path]
107
    cache_dir = tmp_path / "cache"
108
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
109
    dataset = JsonDatasetReader(path, cache_dir=cache_dir).read()
110
    _check_json_dataset(dataset, expected_features)
111

112

113
def _check_json_datasetdict(dataset_dict, expected_features, splits=("train",)):
114
    assert isinstance(dataset_dict, DatasetDict)
115
    for split in splits:
116
        dataset = dataset_dict[split]
117
        assert dataset.num_rows == 4
118
        assert dataset.num_columns == 3
119
        assert dataset.column_names == ["col_1", "col_2", "col_3"]
120
        for feature, expected_dtype in expected_features.items():
121
            assert dataset.features[feature].dtype == expected_dtype
122

123

124
@pytest.mark.parametrize("keep_in_memory", [False, True])
125
def test_datasetdict_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
126
    cache_dir = tmp_path / "cache"
127
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
128
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
129
        dataset = JsonDatasetReader({"train": jsonl_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
130
    _check_json_datasetdict(dataset, expected_features)
131

132

133
@pytest.mark.parametrize(
134
    "features",
135
    [
136
        None,
137
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
138
        {"col_1": "string", "col_2": "string", "col_3": "string"},
139
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
140
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
141
    ],
142
)
143
def test_datasetdict_from_json_features(features, jsonl_path, tmp_path):
144
    cache_dir = tmp_path / "cache"
145
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
146
    expected_features = features.copy() if features else default_expected_features
147
    features = (
148
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
149
    )
150
    dataset = JsonDatasetReader({"train": jsonl_path}, features=features, cache_dir=cache_dir).read()
151
    _check_json_datasetdict(dataset, expected_features)
152

153

154
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
155
def test_datasetdict_from_json_splits(split, jsonl_path, tmp_path):
156
    if split:
157
        path = {split: jsonl_path}
158
    else:
159
        split = "train"
160
        path = {"train": jsonl_path, "test": jsonl_path}
161
    cache_dir = tmp_path / "cache"
162
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
163
    dataset = JsonDatasetReader(path, cache_dir=cache_dir).read()
164
    _check_json_datasetdict(dataset, expected_features, splits=list(path.keys()))
165
    assert all(dataset[split].split == split for split in path.keys())
166

167

168
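# Helpers to parse the writer's output back out of a buffer: one for a single
# JSON document, one for JSON Lines.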
def load_json(buffer):
    return json.load(buffer)


def load_json_lines(buffer):
    return [json.loads(line) for line in buffer]


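# Writer tests: `JsonDatasetWriter` exports the 10-row `dataset` fixture to a
# buffer or path, either as JSON Lines or in one of the `orient` layouts also
# used by pandas (records, split, index, columns, values, table).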
class TestJsonDatasetWriter:
    @pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
    def test_dataset_to_json_lines(self, lines, load_json_function, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=lines).write()
            buffer.seek(0)
            exported_content = load_json_function(buffer)
        assert isinstance(exported_content, list)
        assert isinstance(exported_content[0], dict)
        assert len(exported_content) == 10

    @pytest.mark.parametrize(
        "orient, container, keys, len_at",
        [
            ("records", list, {"tokens", "labels", "answers", "id"}, None),
            ("split", dict, {"columns", "data"}, "data"),
            ("index", dict, set("0123456789"), None),
            ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
            ("values", list, None, None),
            ("table", dict, {"schema", "data"}, "data"),
        ],
    )
    def test_dataset_to_json_orient(self, orient, container, keys, len_at, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=False, orient=orient).write()
            buffer.seek(0)
            exported_content = load_json(buffer)
        assert isinstance(exported_content, container)
        if keys:
            if container is dict:
                assert exported_content.keys() == keys
            else:
                assert exported_content[0].keys() == keys
        else:
            assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
        if len_at:
            assert len(exported_content[len_at]) == 10
        else:
            assert len(exported_content) == 10

    @pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
    def test_dataset_to_json_lines_multiproc(self, lines, load_json_function, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=lines, num_proc=2).write()
            buffer.seek(0)
            exported_content = load_json_function(buffer)
        assert isinstance(exported_content, list)
        assert isinstance(exported_content[0], dict)
        assert len(exported_content) == 10

    @pytest.mark.parametrize(
        "orient, container, keys, len_at",
        [
            ("records", list, {"tokens", "labels", "answers", "id"}, None),
            ("split", dict, {"columns", "data"}, "data"),
            ("index", dict, set("0123456789"), None),
            ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
            ("values", list, None, None),
            ("table", dict, {"schema", "data"}, "data"),
        ],
    )
    def test_dataset_to_json_orient_multiproc(self, orient, container, keys, len_at, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=False, orient=orient, num_proc=2).write()
            buffer.seek(0)
            exported_content = load_json(buffer)
        assert isinstance(exported_content, container)
        if keys:
            if container is dict:
                assert exported_content.keys() == keys
            else:
                assert exported_content[0].keys() == keys
        else:
            assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
        if len_at:
            assert len(exported_content[len_at]) == 10
        else:
            assert len(exported_content) == 10

    def test_dataset_to_json_orient_invalidproc(self, dataset):
        with pytest.raises(ValueError):
            with io.BytesIO() as buffer:
                JsonDatasetWriter(dataset, buffer, num_proc=0)

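    # Round-trip through fsspec with inferred compression and compare
    # byte-for-byte against the reference file in `shared_datadir`.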
    @pytest.mark.parametrize("compression, extension", [("gzip", "gz"), ("bz2", "bz2"), ("xz", "xz")])
261
    def test_dataset_to_json_compression(self, shared_datadir, tmp_path_factory, extension, compression, dataset):
262
        path = tmp_path_factory.mktemp("data") / f"test.json.{extension}"
263
        original_path = str(shared_datadir / f"test_file.json.{extension}")
264
        JsonDatasetWriter(dataset, path, compression=compression).write()
265

266
        with fsspec.open(path, "rb", compression="infer") as f:
267
            exported_content = f.read()
268
        with fsspec.open(original_path, "rb", compression="infer") as f:
269
            original_content = f.read()
270
        assert exported_content == original_content
271
