datasets

Форк
0
/
test_file_utils.py 
141 строка · 4.9 Кб
1
import os
2
from pathlib import Path
3
from unittest.mock import patch
4

5
import pytest
6
import zstandard as zstd
7

8
from datasets.download.download_config import DownloadConfig
9
from datasets.utils.file_utils import (
10
    OfflineModeIsEnabled,
11
    cached_path,
12
    fsspec_get,
13
    fsspec_head,
14
    ftp_get,
15
    ftp_head,
16
    get_from_cache,
17
    http_get,
18
    http_head,
19
)
20

21

22
# Payload shared by the fixtures and tests below: two lines, each indented by
# four spaces, with no trailing newline.
FILE_CONTENT = "    Text data.\n    Second line of data."

# Bare file name reused by the zstd fixture (with a ".zstd" suffix) and by the
# tmpfs fixture.
FILE_PATH = "file"
27

28

29
@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
    """Write ``FILE_CONTENT`` to a zstd-compressed file and return its path.

    Session-scoped: the archive is created once in a fresh temporary "data"
    directory and shared by every test that requests this fixture.
    """
    archive = tmp_path_factory.mktemp("data") / f"{FILE_PATH}.zstd"
    payload = FILE_CONTENT.encode("utf-8")
    with zstd.open(archive, "wb") as writer:
        writer.write(payload)
    return archive
36

37

38
@pytest.fixture
def tmpfs_file(tmpfs):
    """Create ``FILE_PATH`` holding ``FILE_CONTENT`` under the tmpfs local root.

    Returns the name relative to that root (just ``FILE_PATH``); tests combine
    it into a ``tmp://`` URL.
    """
    Path(tmpfs.local_root_dir, FILE_PATH).write_text(FILE_CONTENT)
    return FILE_PATH
43

44

45
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
    """``cached_path`` with ``extract_compressed_file=True`` decompresses each format.

    The extracted file must have the same content as the uncompressed
    ``text_file`` fixture.
    """
    compressed_by_format = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
    download_config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted_path = cached_path(compressed_by_format[compression_format], download_config=download_config)

    with open(text_file) as expected_handle:
        expected = expected_handle.read()
    with open(extracted_path) as actual_handle:
        actual = actual_handle.read()
    assert actual == expected
57

58

59
@pytest.mark.parametrize("default_extracted", [True, False])
@pytest.mark.parametrize("default_cache_dir", [True, False])
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
    """Check the directory in which ``cached_path`` places an extracted archive.

    ``default_cache_dir`` chooses between the default cache and a custom
    ``cache_dir`` passed through ``DownloadConfig``. ``default_extracted`` chooses
    between the default extraction settings and ``datasets.config`` extraction
    settings patched to custom values. The assertion compares the last two
    components of the extracted file's parent directory.
    """
    custom_cache_dir = "custom_cache"
    custom_extracted_dir = "custom_extracted_dir"
    custom_extracted_path = tmp_path / "custom_extracted_path"

    if not default_extracted:
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir)
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path))

    if default_cache_dir:
        download_config = DownloadConfig(extract_compressed_file=True)
        if default_extracted:
            expected = ("downloads", "extracted")
        else:
            expected = custom_extracted_path.parts[-2:]
    else:
        download_config = DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True)
        expected = (custom_cache_dir, "extracted" if default_extracted else custom_extracted_dir)

    extracted_file_path = cached_path(xz_file, download_config=download_config)
    assert Path(extracted_file_path).parent.parts[-2:] == expected
80

81

82
def test_cached_path_local(text_file):
    """``cached_path`` resolves an existing local file from an absolute or a relative path.

    Both checks use ``os.path.samefile``, so only the identity of the returned
    file is asserted, not whether the returned path is absolute.
    """
    # Absolute input path.
    text_file_abs = str(Path(text_file).resolve())
    assert os.path.samefile(cached_path(text_file_abs), text_file_abs)

    # Relative input path. Use this test module itself, which is guaranteed to
    # exist on disk. A separate name avoids shadowing the ``text_file`` fixture.
    # ``os.path.relpath`` (unlike ``Path.relative_to``) still works when pytest
    # is launched from a directory that does not contain this file: it yields a
    # ``..``-prefixed relative path instead of raising ``ValueError``.
    this_file_abs = str(Path(__file__).resolve())
    this_file_rel = os.path.relpath(this_file_abs)
    assert os.path.samefile(cached_path(this_file_rel), this_file_abs)
91

92

93
def test_cached_path_missing_local(tmp_path):
    """A non-existent local path raises ``FileNotFoundError``, whether absolute or relative."""
    candidates = (
        str(tmp_path.resolve() / "__missing_file__.txt"),  # absolute path
        "./__missing_file__.txt",  # path relative to the current working directory
    )
    for missing_file in candidates:
        with pytest.raises(FileNotFoundError):
            cached_path(missing_file)
102

103

104
def test_get_from_cache_fsspec(tmpfs_file):
    """``get_from_cache`` on a ``tmp://`` URL yields a local file with the original content."""
    url = f"tmp://{tmpfs_file}"
    local_copy = get_from_cache(url)
    with open(local_copy) as handle:
        assert handle.read() == FILE_CONTENT
109

110

111
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_cached_path_offline():
    """With offline mode patched on, ``cached_path`` rejects a remote URL."""
    remote_url = "https://huggingface.co"
    with pytest.raises(OfflineModeIsEnabled):
        cached_path(remote_url)
115

116

117
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_http_offline(tmp_path_factory):
    """With offline mode patched on, both HTTP helpers raise ``OfflineModeIsEnabled``."""
    url = "https://huggingface.co"
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with pytest.raises(OfflineModeIsEnabled):
        http_get(url, temp_file=destination)
    with pytest.raises(OfflineModeIsEnabled):
        http_head(url)
124

125

126
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_ftp_offline(tmp_path_factory):
    """With offline mode patched on, both FTP helpers raise ``OfflineModeIsEnabled``."""
    url = "ftp://huggingface.co"
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with pytest.raises(OfflineModeIsEnabled):
        ftp_get(url, temp_file=destination)
    with pytest.raises(OfflineModeIsEnabled):
        ftp_head(url)
133

134

135
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_fsspec_offline(tmp_path_factory):
    """With offline mode patched on, both fsspec helpers raise ``OfflineModeIsEnabled`` for an s3 URL."""
    url = "s3://huggingface.co"
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_get(url, temp_file=destination)
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_head(url)
142

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.