6
from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
7
from datasets.io.csv import CsvDatasetReader, CsvDatasetWriter
9
from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases
12
def _check_csv_dataset(dataset, expected_features):
    """Assert that ``dataset`` has the shape and column dtypes expected of the CSV test file.

    Args:
        dataset: Result of ``CsvDatasetReader(...).read()``; must be a ``Dataset``.
        expected_features: Mapping of column name to the expected dtype string,
            compared against ``dataset.features[column].dtype``.
    """
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
21
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_csv_keep_in_memory(keep_in_memory, csv_path, tmp_path):
    """Arrow memory grows while reading a single CSV file only when ``keep_in_memory`` is set."""
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    # Pick the memory assertion matching the requested mode.
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = CsvDatasetReader(csv_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_csv_dataset(dataset, expected_features)
30
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_csv_features(features, csv_path, tmp_path):
    """Explicit ``features`` override the dtypes the reader would otherwise infer.

    ``features`` is a plain ``{column: dtype}`` dict, or ``None`` to rely on
    the default dtypes; it is converted to a ``Features`` object before reading.
    """
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    # Capture the expected dtypes before `features` is rebound to a Features object below.
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = CsvDatasetReader(csv_path, features=features, cache_dir=cache_dir).read()
    _check_csv_dataset(dataset, expected_features)
52
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_csv_split(split, csv_path, tmp_path):
    """The dataset read from a single CSV file reports the requested split ("train" when none is given)."""
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(csv_path, cache_dir=cache_dir, split=split).read()
    _check_csv_dataset(dataset, expected_features)
    # Parenthesized on purpose: `a == b if c else d` parses as `(a == b) if c else d`,
    # which made this assertion vacuous ("train" is truthy) whenever split was None.
    assert dataset.split == (split if split else "train")
61
@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_csv_path_type(path_type, csv_path, tmp_path):
    """The reader accepts either a single path or a list of paths."""
    if issubclass(path_type, str):
        path = csv_path
    elif issubclass(path_type, list):
        path = [csv_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
    _check_csv_dataset(dataset, expected_features)
73
def _check_csv_datasetdict(dataset_dict, expected_features, splits=("train",)):
    """Assert that ``dataset_dict`` is a ``DatasetDict`` whose listed splits match the CSV test file.

    Args:
        dataset_dict: Result of ``CsvDatasetReader({...}).read()``.
        expected_features: Mapping of column name to expected dtype string.
        splits: Split names to look up in ``dataset_dict`` and check.
    """
    assert isinstance(dataset_dict, DatasetDict)
    for split in splits:
        # Reuse the single-dataset checks so both helpers stay in sync.
        _check_csv_dataset(dataset_dict[split], expected_features)
84
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_csv_datasetdict_reader_keep_in_memory(keep_in_memory, csv_path, tmp_path):
    """Arrow memory grows while reading a dict of CSV paths only when ``keep_in_memory`` is set."""
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    # Pick the memory assertion matching the requested mode.
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_csv_datasetdict(dataset, expected_features)
93
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_csv_datasetdict_reader_features(features, csv_path, tmp_path):
    """Explicit ``features`` override inferred dtypes when reading a dict of CSV paths.

    ``features`` is a plain ``{column: dtype}`` dict, or ``None`` to rely on
    the default dtypes; it is converted to a ``Features`` object before reading.
    """
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    # Capture the expected dtypes before `features` is rebound to a Features object below.
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = CsvDatasetReader({"train": csv_path}, features=features, cache_dir=cache_dir).read()
    _check_csv_datasetdict(dataset, expected_features)
115
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_csv_datasetdict_reader_split(split, csv_path, tmp_path):
    """Keys of a dict path become the split names of the returned DatasetDict."""
    if split:
        path = {split: csv_path}
    else:
        # No split requested: read the same file as both "train" and "test".
        path = {"train": csv_path, "test": csv_path}
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
    _check_csv_datasetdict(dataset, expected_features, splits=list(path.keys()))
    assert all(dataset[split_name].split == split_name for split_name in path)
128
def iter_csv_file(csv_path):
    """Yield each row of the UTF-8 CSV file at ``csv_path`` as a list of strings.

    The file is opened with ``newline=""``, as the csv module requires, so quoted
    fields that contain line breaks are returned verbatim instead of having their
    line endings rewritten by universal-newline translation.
    """
    with open(csv_path, encoding="utf-8", newline="") as csvfile:
        yield from csv.reader(csvfile)
133
def test_dataset_to_csv(csv_path, tmp_path):
    """Writing the dataset back to CSV with one process reproduces the source file's rows."""
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    CsvDatasetWriter(dataset["train"], output_csv, num_proc=1).write()

    # Compare complete row lists rather than zip()-ing the generators: zip stops at
    # the shorter input, so a truncated or empty output file would otherwise pass.
    # Materializing also exhausts both generators, closing their files promptly.
    original_rows = list(iter_csv_file(csv_path))
    written_rows = list(iter_csv_file(output_csv))
    assert written_rows == original_rows
146
def test_dataset_to_csv_multiproc(csv_path, tmp_path):
    """Writing the dataset back to CSV with two processes reproduces the source file's rows."""
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    CsvDatasetWriter(dataset["train"], output_csv, num_proc=2).write()

    # Compare complete row lists rather than zip()-ing the generators: zip stops at
    # the shorter input, so a truncated or empty output file would otherwise pass.
    # Materializing also exhausts both generators, closing their files promptly.
    original_rows = list(iter_csv_file(csv_path))
    written_rows = list(iter_csv_file(output_csv))
    assert written_rows == original_rows
159
def test_dataset_to_csv_invalidproc(csv_path, tmp_path):
    """Constructing a CsvDatasetWriter with ``num_proc=0`` raises ``ValueError``."""
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    # Only the constructor is called: the invalid value must be rejected before any write.
    with pytest.raises(ValueError):
        CsvDatasetWriter(dataset["train"], output_csv, num_proc=0)