import contextlib
import os
import sqlite3

import pytest

from datasets import Dataset, Features, Value
from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter

from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_sqlalchemy
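
# These tests rely on shared fixtures, expected to live in the suite's
# conftest.py: `sqlite_path` points at a SQLite database containing a table
# named "dataset" with 4 rows and columns col_1/col_2/col_3 (the checks below
# assert exactly that shape), and `set_sqlalchemy_silence_uber_warning`
# appears to silence SQLAlchemy's deprecation "uber warning" during the tests.
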
def _check_sql_dataset(dataset, expected_features):
    # Shared assertions for the fixture table: 4 rows, 3 known columns, and
    # the expected Arrow dtype for each feature.
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@require_sqlalchemy
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    # With keep_in_memory=True the table is loaded into RAM, so Arrow memory
    # should grow; otherwise it is memory-mapped from the cache and stays flat.
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = SqlDatasetReader(
            "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory
        ).read()
    _check_sql_dataset(dataset, expected_features)
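
# Note: SqlDatasetReader is the reader behind the public Dataset.from_sql()
# API, so the call above corresponds roughly to (a sketch, same fixture DB):
#
#     ds = Dataset.from_sql("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir)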


@require_sqlalchemy
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_sql_features(features, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    # None means "let the reader infer the schema"; otherwise build an explicit
    # Features object from the requested dtypes and expect it to be honored.
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir).read()
    _check_sql_dataset(dataset, expected_features)


def iter_sql_file(sqlite_path):
    # Yield every row of the "dataset" table so the round-trip tests can
    # compare the original database against the one written by the writer.
    with contextlib.closing(sqlite3.connect(sqlite_path)) as con:
        cur = con.cursor()
        cur.execute("SELECT * FROM dataset")
        for row in cur:
            yield row


@require_sqlalchemy
def test_dataset_to_sql(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=1).write()

    original_sql = iter_sql_file(sqlite_path)
    expected_sql = iter_sql_file(output_sqlite_path)

    # Single-process write: the output table must match the source row for row.
    for row1, row2 in zip(original_sql, expected_sql):
        assert row1 == row2


@require_sqlalchemy
def test_dataset_to_sql_multiproc(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=2).write()

    original_sql = iter_sql_file(sqlite_path)
    expected_sql = iter_sql_file(output_sqlite_path)

    # Writing with two processes must still preserve the content and the order.
    for row1, row2 in zip(original_sql, expected_sql):
        assert row1 == row2


@require_sqlalchemy
def test_dataset_to_sql_invalidproc(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    # num_proc must be a positive integer; 0 is rejected with a ValueError.
    with pytest.raises(ValueError):
        SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=0).write()
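
# Note: SqlDatasetWriter backs the public Dataset.to_sql() API, so the writer
# calls above are roughly equivalent to (a sketch, same output path):
#
#     dataset.to_sql("dataset", "sqlite:///" + output_sqlite_path, num_proc=2)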