datasets

Форк
0
/
test_sql.py 
98 строк · 4.0 Кб
1
import contextlib
2
import os
3
import sqlite3
4

5
import pytest
6

7
from datasets import Dataset, Features, Value
8
from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter
9

10
from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_sqlalchemy
11

12

13
def _check_sql_dataset(dataset, expected_features):
    """Sanity-check a dataset read from the SQL fixture.

    Verifies the object type, the fixed fixture shape (4 rows x 3 columns),
    the column names, and that every column listed in ``expected_features``
    carries the expected dtype.
    """
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for name in expected_features:
        assert dataset.features[name].dtype == expected_features[name]
20

21

22
@require_sqlalchemy
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    """Reading with keep_in_memory=True must grow Arrow memory; False must not."""
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    # Pick the memory assertion that matches the requested caching mode.
    memory_check = assert_arrow_memory_increases if keep_in_memory else assert_arrow_memory_doesnt_increase
    with memory_check():
        reader = SqlDatasetReader(
            "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory
        )
        dataset = reader.read()
    _check_sql_dataset(dataset, expected_features)
32

33

34
@require_sqlalchemy
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_sql_features(features, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    """The reader must honor an explicit Features schema, or infer the defaults when given None."""
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    # Build the Features object only when an explicit dtype mapping was supplied.
    schema = None if features is None else Features({name: Value(dtype) for name, dtype in features.items()})
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, features=schema, cache_dir=cache_dir).read()
    _check_sql_dataset(dataset, expected_features)
54

55

56
def iter_sql_file(sqlite_path):
    """Yield every row of the ``dataset`` table stored in the sqlite file at *sqlite_path*.

    The connection is closed once iteration finishes, via ``contextlib.closing``.
    """
    with contextlib.closing(sqlite3.connect(sqlite_path)) as con:
        # Connection.execute creates a cursor and returns it; iterate it directly.
        yield from con.execute("SELECT * FROM dataset")
62

63

64
@require_sqlalchemy
def test_dataset_to_sql(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    """Round-trip: writing a dataset back to sqlite must reproduce the source rows exactly."""
    cache_dir = tmp_path / "cache"
    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=1).write()

    # Materialize both sides and compare the full lists: a bare zip() stops at the
    # shorter iterable, so a missing/extra row in the written table would previously
    # go undetected. List equality also checks row count and order.
    original_rows = list(iter_sql_file(sqlite_path))
    written_rows = list(iter_sql_file(output_sqlite_path))
    assert written_rows == original_rows
76

77

78
@require_sqlalchemy
def test_dataset_to_sql_multiproc(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    """Round-trip with num_proc=2: multiprocess writing must reproduce the source rows exactly."""
    cache_dir = tmp_path / "cache"
    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=2).write()

    # Materialize both sides and compare the full lists: a bare zip() stops at the
    # shorter iterable, so a worker dropping or duplicating rows would previously
    # go undetected. List equality also checks row count and order.
    original_rows = list(iter_sql_file(sqlite_path))
    written_rows = list(iter_sql_file(output_sqlite_path))
    assert written_rows == original_rows
90

91

92
@require_sqlalchemy
def test_dataset_to_sql_invalidproc(sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    """Writing with num_proc=0 is invalid and must raise ValueError."""
    cache_dir = tmp_path / "cache"
    destination = os.path.join(cache_dir, "tmp.sql")
    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
    # Keep construction inside the raises block: the ValueError may come from
    # the writer's constructor rather than from write() itself.
    with pytest.raises(ValueError):
        SqlDatasetWriter(dataset, "dataset", "sqlite:///" + destination, num_proc=0).write()
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.