import csv
import os

import pytest

from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
from datasets.io.csv import CsvDatasetReader, CsvDatasetWriter

from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases


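# Note: `tmp_path` is a built-in pytest fixture; `csv_path` is assumed to be provided
# by the suite's conftest and to point at a CSV file with a "col_1,col_2,col_3" header
# and 4 data rows, whose columns are inferred by default as int64, int64 and float64
# (exactly the shape checked by _check_csv_dataset below).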
def _check_csv_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


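# keep_in_memory=True materializes the Arrow table in RAM instead of memory-mapping
# the on-disk cache, so Arrow memory should grow only in that case.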
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_csv_keep_in_memory(keep_in_memory, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = CsvDatasetReader(csv_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_csv_dataset(dataset, expected_features)


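# Explicitly passed Features should override the dtypes inferred from the CSV;
# features=None falls back to the inferred defaults.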
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_csv_features(features, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    # CSV does not preserve dtype information, so without explicit features
    # col_1 is inferred as "int64" rather than "string"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = CsvDatasetReader(csv_path, features=features, cache_dir=cache_dir).read()
    _check_csv_dataset(dataset, expected_features)


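# When no split is given, the reader defaults to the "train" split.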
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_csv_split(split, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(csv_path, cache_dir=cache_dir, split=split).read()
    _check_csv_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


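# The reader accepts either a single path or a list of paths.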
@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_csv_path_type(path_type, csv_path, tmp_path):
    if issubclass(path_type, str):
        path = csv_path
    elif issubclass(path_type, list):
        path = [csv_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
    _check_csv_dataset(dataset, expected_features)


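# The DatasetDict variants below mirror the Dataset tests above, but pass a
# {split_name: path} mapping so that the reader returns a DatasetDict.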
def _check_csv_datasetdict(dataset_dict, expected_features, splits=("train",)):
    assert isinstance(dataset_dict, DatasetDict)
    for split in splits:
        dataset = dataset_dict[split]
        assert dataset.num_rows == 4
        assert dataset.num_columns == 3
        assert dataset.column_names == ["col_1", "col_2", "col_3"]
        for feature, expected_dtype in expected_features.items():
            assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_csv_datasetdict_reader_keep_in_memory(keep_in_memory, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read()
    _check_csv_datasetdict(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_csv_datasetdict_reader_features(features, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    # CSV does not preserve dtype information, so without explicit features
    # col_1 is inferred as "int64" rather than "string"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = CsvDatasetReader({"train": csv_path}, features=features, cache_dir=cache_dir).read()
    _check_csv_datasetdict(dataset, expected_features)


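# With a mapping input, the dict keys name the resulting splits; when the
# parametrized split is None, the same file is read into both "train" and "test".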
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_csv_datasetdict_reader_split(split, csv_path, tmp_path):
    if split:
        path = {split: csv_path}
    else:
        path = {"train": csv_path, "test": csv_path}
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
    _check_csv_datasetdict(dataset, expected_features, splits=list(path.keys()))
    assert all(dataset[split].split == split for split in path.keys())


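# Round-trip tests: write a Dataset back to CSV and compare it row by row with the
# file it was originally read from.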
def iter_csv_file(csv_path):
    with open(csv_path, encoding="utf-8") as csvfile:
        yield from csv.reader(csvfile)


def test_dataset_to_csv(csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    CsvDatasetWriter(dataset["train"], output_csv, num_proc=1).write()

    original_csv = iter_csv_file(csv_path)
    expected_csv = iter_csv_file(output_csv)

    for row1, row2 in zip(original_csv, expected_csv):
        assert row1 == row2


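# num_proc=2 parallelizes the write across two processes; the output must still
# match the source file row for row.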
def test_dataset_to_csv_multiproc(csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    CsvDatasetWriter(dataset["train"], output_csv, num_proc=2).write()

    original_csv = iter_csv_file(csv_path)
    expected_csv = iter_csv_file(output_csv)

    for row1, row2 in zip(original_csv, expected_csv):
        assert row1 == row2


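# A non-positive num_proc is invalid: constructing the writer with num_proc=0
# should raise ValueError.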
def test_dataset_to_csv_invalidproc(csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    output_csv = os.path.join(cache_dir, "tmp.csv")
    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
    with pytest.raises(ValueError):
        CsvDatasetWriter(dataset["train"], output_csv, num_proc=0)
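

# For reference, a minimal sketch of the user-facing API that these reader/writer
# classes back (paths "data.csv" and "out.csv" are placeholders):
#
#     from datasets import load_dataset
#
#     dset = load_dataset("csv", data_files={"train": "data.csv"})["train"]
#     dset.to_csv("out.csv", num_proc=2)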