import io
import json

import fsspec
import pytest

from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
from datasets.io.json import JsonDatasetReader, JsonDatasetWriter

from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases


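# The `jsonl_path`, `jsonl_312_path`, `shared_datadir`, and `dataset` fixtures
# used below are assumed to be provided by the suite's conftest.py.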
def _check_json_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = JsonDatasetReader(jsonl_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_json_dataset(dataset, expected_features)


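# Passing `features` explicitly should override schema inference, including
# casts away from the default string/int64/float64 types.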
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_json_features(features, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = JsonDatasetReader(jsonl_path, features=features, cache_dir=cache_dir).read()
    _check_json_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_3": "float64", "col_1": "string", "col_2": "int64"},
    ],
)
def test_dataset_from_json_with_unsorted_column_names(features, jsonl_312_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_3": "float64", "col_1": "string", "col_2": "int64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_3", "col_1", "col_2"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


def test_dataset_from_json_with_mismatched_features(jsonl_312_path, tmp_path):
    # jsonl_312_path features are {"col_3": "float64", "col_1": "string", "col_2": "int64"}
    features = {"col_2": "int64", "col_3": "float64", "col_1": "string"}
    expected_features = features.copy()
    features = Features({feature: Value(dtype) for feature, dtype in features.items()})
    cache_dir = tmp_path / "cache"
    dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_2", "col_3", "col_1"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


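# `split` accepts a NamedSplit or a plain string; with no split given, the
# reader falls back to "train".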
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_json_split(split, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = JsonDatasetReader(jsonl_path, cache_dir=cache_dir, split=split).read()
    _check_json_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path):
    if issubclass(path_type, str):
        path = jsonl_path
    elif issubclass(path_type, list):
        path = [jsonl_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = JsonDatasetReader(path, cache_dir=cache_dir).read()
    _check_json_dataset(dataset, expected_features)


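# DatasetDict variant: the reader maps {split_name: path} to one Dataset per split.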
def _check_json_datasetdict(dataset_dict, expected_features, splits=("train",)):
    assert isinstance(dataset_dict, DatasetDict)
    for split in splits:
        dataset = dataset_dict[split]
        assert dataset.num_rows == 4
        assert dataset.num_columns == 3
        assert dataset.column_names == ["col_1", "col_2", "col_3"]
        for feature, expected_dtype in expected_features.items():
            assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_datasetdict_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = JsonDatasetReader({"train": jsonl_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_json_datasetdict(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_datasetdict_from_json_features(features, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = JsonDatasetReader({"train": jsonl_path}, features=features, cache_dir=cache_dir).read()
    _check_json_datasetdict(dataset, expected_features)


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_datasetdict_from_json_splits(split, jsonl_path, tmp_path):
    if split:
        path = {split: jsonl_path}
    else:
        path = {"train": jsonl_path, "test": jsonl_path}
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = JsonDatasetReader(path, cache_dir=cache_dir).read()
    _check_json_datasetdict(dataset, expected_features, splits=list(path.keys()))
    assert all(dataset[split].split == split for split in path.keys())


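# Writer tests: small helpers to decode what JsonDatasetWriter wrote, either as
# a single JSON document or as JSON Lines.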
def load_json(buffer):
    return json.load(buffer)


def load_json_lines(buffer):
    return [json.loads(line) for line in buffer]


class TestJsonDatasetWriter:
    @pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
    def test_dataset_to_json_lines(self, lines, load_json_function, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=lines).write()
            buffer.seek(0)  # rewind so the written content can be read back
            exported_content = load_json_function(buffer)
        assert isinstance(exported_content, list)
        assert isinstance(exported_content[0], dict)
        assert len(exported_content) == 10

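    # The `orient` layouts mirror pandas' `DataFrame.to_json`: each case pins
    # down the top-level container, its keys, and where the row count lands.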
    @pytest.mark.parametrize(
        "orient, container, keys, len_at",
        [
            ("records", list, {"tokens", "labels", "answers", "id"}, None),
            ("split", dict, {"columns", "data"}, "data"),
            ("index", dict, set("0123456789"), None),
            ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
            ("values", list, None, None),
            ("table", dict, {"schema", "data"}, "data"),
        ],
    )
    def test_dataset_to_json_orient(self, orient, container, keys, len_at, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=False, orient=orient).write()
            buffer.seek(0)
            exported_content = load_json(buffer)
        assert isinstance(exported_content, container)
        if keys:
            if container is dict:
                assert exported_content.keys() == keys
            else:
                assert exported_content[0].keys() == keys
        else:
            assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
        if len_at:
            assert len(exported_content[len_at]) == 10
        else:
            assert len(exported_content) == 10

    @pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
    def test_dataset_to_json_lines_multiproc(self, lines, load_json_function, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=lines, num_proc=2).write()
            buffer.seek(0)
            exported_content = load_json_function(buffer)
        assert isinstance(exported_content, list)
        assert isinstance(exported_content[0], dict)
        assert len(exported_content) == 10

    @pytest.mark.parametrize(
        "orient, container, keys, len_at",
        [
            ("records", list, {"tokens", "labels", "answers", "id"}, None),
            ("split", dict, {"columns", "data"}, "data"),
            ("index", dict, set("0123456789"), None),
            ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
            ("values", list, None, None),
            ("table", dict, {"schema", "data"}, "data"),
        ],
    )
    def test_dataset_to_json_orient_multiproc(self, orient, container, keys, len_at, dataset):
        with io.BytesIO() as buffer:
            JsonDatasetWriter(dataset, buffer, lines=False, orient=orient, num_proc=2).write()
            buffer.seek(0)
            exported_content = load_json(buffer)
        assert isinstance(exported_content, container)
        if keys:
            if container is dict:
                assert exported_content.keys() == keys
            else:
                assert exported_content[0].keys() == keys
        else:
            assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
        if len_at:
            assert len(exported_content[len_at]) == 10
        else:
            assert len(exported_content) == 10

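    # `num_proc` must be a positive integer; 0 should already fail at construction.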
    def test_dataset_to_json_orient_invalidproc(self, dataset):
        with pytest.raises(ValueError):
            with io.BytesIO() as buffer:
                JsonDatasetWriter(dataset, buffer, num_proc=0)

    @pytest.mark.parametrize("compression, extension", [("gzip", "gz"), ("bz2", "bz2"), ("xz", "xz")])
    def test_dataset_to_json_compression(self, shared_datadir, tmp_path_factory, extension, compression, dataset):
        path = tmp_path_factory.mktemp("data") / f"test.json.{extension}"
        original_path = str(shared_datadir / f"test_file.json.{extension}")
        JsonDatasetWriter(dataset, path, compression=compression).write()
        assert path.exists()
        # Decompress both files so the comparison is on the JSON payload itself.
        with fsspec.open(path, "rb", compression="infer") as f:
            exported_content = f.read()
        with fsspec.open(original_path, "rb", compression="infer") as f:
            original_content = f.read()
        assert exported_content == original_content