2
from pathlib import Path
3
from unittest.mock import patch
6
import zstandard as zstd
8
from datasets.download.download_config import DownloadConfig
9
from datasets.utils.file_utils import (
24
Second line of data."""
29
@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
    """Write ``FILE_CONTENT`` to a zstd-compressed file and return its path.

    The file is created once per test session inside a fresh ``data``
    temporary directory and is named ``FILE_PATH + ".zstd"``.
    """
    path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd")
    data = bytes(FILE_CONTENT, "utf-8")
    with zstd.open(path, "wb") as f:
        # The compressed payload must actually be written; otherwise the
        # ``with`` statement has no body and ``data`` is never used.
        f.write(data)
    # Tests use this fixture's value as the compressed input path.
    return path
40
with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
45
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
    """Check that ``cached_path`` extracts a compressed file to the expected text.

    For each compression format, the matching compressed fixture is passed to
    ``cached_path`` with extraction enabled, and the extracted content is
    compared with the content of ``text_file``.
    """
    compressed_by_format = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
    config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted = cached_path(compressed_by_format[compression_format], download_config=config)
    assert Path(extracted).read_text() == Path(text_file).read_text()
59
@pytest.mark.parametrize("default_extracted", [True, False])
@pytest.mark.parametrize("default_cache_dir", [True, False])
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
    """Check the last two path components of the directory used for extraction.

    ``default_extracted`` selects whether the extraction settings in
    ``datasets.config`` are left alone or monkeypatched; ``default_cache_dir``
    selects whether ``DownloadConfig`` receives an explicit ``cache_dir``.
    """
    custom_cache_dir = "custom_cache"
    custom_extracted_dir = "custom_extracted_dir"
    custom_extracted_path = tmp_path / "custom_extracted_path"
    # NOTE(review): the branch structure below was restored from the parameter
    # names and the dangling ``else``; confirm against the upstream test.
    if default_extracted:
        expected = ("downloads" if default_cache_dir else custom_cache_dir, "extracted")
    else:
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir)
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path))
        expected = custom_extracted_path.parts[-2:] if default_cache_dir else (custom_cache_dir, custom_extracted_dir)

    # ``xz_file`` is the only compressed fixture this test requests.
    filename = xz_file
    download_config = (
        DownloadConfig(extract_compressed_file=True)
        if default_cache_dir
        else DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True)
    )
    extracted_file_path = cached_path(filename, download_config=download_config)
    assert Path(extracted_file_path).parent.parts[-2:] == expected
82
def test_cached_path_local(text_file):
    """Check that ``cached_path`` maps absolute and relative local paths to the same file."""
    text_file_abs = str(Path(text_file).resolve())
    # Absolute input path.
    assert os.path.samefile(cached_path(text_file_abs), text_file_abs)
    # Relative input path. ``os.path.relpath`` is used instead of
    # ``Path.relative_to(cwd)`` because the latter raises ``ValueError`` when
    # the file does not live under the current working directory.
    text_file_rel = os.path.relpath(text_file_abs)
    assert os.path.samefile(cached_path(text_file_rel), text_file_abs)
93
def test_cached_path_missing_local(tmp_path):
    """Check that ``cached_path`` raises ``FileNotFoundError`` for missing local files.

    Both an absolute path (under ``tmp_path``) and a ``./``-relative path are tried.
    """
    absolute_missing = str(tmp_path.resolve() / "__missing_file__.txt")
    relative_missing = "./__missing_file__.txt"
    for candidate in (absolute_missing, relative_missing):
        with pytest.raises(FileNotFoundError):
            cached_path(candidate)
104
def test_get_from_cache_fsspec(tmpfs_file):
    """Check that ``get_from_cache`` on a ``tmp://`` URL yields a file holding ``FILE_CONTENT``."""
    cached_file = get_from_cache(f"tmp://{tmpfs_file}")
    assert Path(cached_file).read_text() == FILE_CONTENT
111
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_cached_path_offline():
    """Check that ``cached_path`` raises ``OfflineModeIsEnabled`` for a remote URL when offline."""
    remote_url = "https://huggingface.co"
    with pytest.raises(OfflineModeIsEnabled):
        cached_path(remote_url)
117
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_http_offline(tmp_path_factory):
    """Check that ``http_get`` and ``http_head`` raise ``OfflineModeIsEnabled`` when offline."""
    url = "https://huggingface.co"
    target = tmp_path_factory.mktemp("data") / "file.html"
    # Each call gets its own ``raises`` context so that both are exercised.
    with pytest.raises(OfflineModeIsEnabled):
        http_get(url, temp_file=target)
    with pytest.raises(OfflineModeIsEnabled):
        http_head(url)
126
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_ftp_offline(tmp_path_factory):
    """Check that ``ftp_get`` and ``ftp_head`` raise ``OfflineModeIsEnabled`` when offline."""
    url = "ftp://huggingface.co"
    target = tmp_path_factory.mktemp("data") / "file.html"
    # Each call gets its own ``raises`` context so that both are exercised.
    with pytest.raises(OfflineModeIsEnabled):
        ftp_get(url, temp_file=target)
    with pytest.raises(OfflineModeIsEnabled):
        ftp_head(url)
135
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_fsspec_offline(tmp_path_factory):
    """Check that ``fsspec_get`` and ``fsspec_head`` raise ``OfflineModeIsEnabled`` when offline."""
    url = "s3://huggingface.co"
    target = tmp_path_factory.mktemp("data") / "file.html"
    # Each call gets its own ``raises`` context so that both are exercised.
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_get(url, temp_file=target)
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_head(url)