text-generation-inference
105 lines · 3.1 KB
1import os
2import requests
3import tempfile
4
5import pytest
6
7import huggingface_hub.constants
8from huggingface_hub import hf_api
9
10import text_generation_server.utils.hub
11from text_generation_server.utils.hub import (
12weight_hub_files,
13download_weights,
14weight_files,
15EntryNotFoundError,
16LocalEntryNotFoundError,
17RevisionNotFoundError,
18)
19
20
@pytest.fixture()
def offline():
    """Force ``text_generation_server.utils.hub`` into offline mode for one test.

    Yields the string ``"offline"``. After the test, the module's previous
    ``HF_HUB_OFFLINE`` value is written back.
    """
    hub_module = text_generation_server.utils.hub
    saved_flag = hub_module.HF_HUB_OFFLINE
    hub_module.HF_HUB_OFFLINE = True
    yield "offline"
    hub_module.HF_HUB_OFFLINE = saved_flag
27
28
@pytest.fixture()
def fresh_cache():
    """Point every Hugging Face hub cache setting at a new, empty temp directory.

    For the duration of one test, this patches:

    * ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``,
    * ``text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE``, and
    * the ``HUGGINGFACE_HUB_CACHE`` environment variable.

    Afterwards each one is restored to the value it had before the test, and
    the temporary directory is removed.
    """
    hub_module = text_generation_server.utils.hub
    env_key = "HUGGINGFACE_HUB_CACHE"
    with tempfile.TemporaryDirectory() as cache_dir:
        saved_constants_cache = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
        # The module may not define the attribute yet. In that case fall back
        # to the library value, which is what was restored previously.
        saved_module_cache = getattr(
            hub_module, "HUGGINGFACE_HUB_CACHE", saved_constants_cache
        )
        # None means the variable was unset. It must be removed again on
        # teardown rather than set to some value that leaks into later tests.
        saved_env_cache = os.environ.get(env_key)

        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = cache_dir
        hub_module.HUGGINGFACE_HUB_CACHE = cache_dir
        os.environ[env_key] = cache_dir
        try:
            yield
        finally:
            # Restore before the temp dir is deleted, so nothing points at a
            # removed path.
            huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = saved_constants_cache
            hub_module.HUGGINGFACE_HUB_CACHE = saved_module_cache
            if saved_env_cache is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_env_cache
40
41
@pytest.fixture()
def prefetched():
    """Download the ``*.safetensors`` files of ``bert-base-uncased`` and yield its id.

    Files are fetched from revision ``main`` with network access allowed
    (``local_files_only=False``), into whatever hub cache is currently
    configured.
    """
    repo = "bert-base-uncased"
    download_kwargs = dict(
        revision="main",
        local_files_only=False,
        repo_type="model",
        allow_patterns=["*.safetensors"],
    )
    huggingface_hub.snapshot_download(repo_id=repo, **download_kwargs)
    yield repo
53
54
def test_weight_hub_files_offline_error(offline, fresh_cache):
    """Offline lookup of a model that is not in the (fresh) cache must fail."""
    missing_model = "gpt2"
    # Nothing was prefetched into the new temp cache, so resolution cannot succeed.
    with pytest.raises(EntryNotFoundError):
        weight_hub_files(missing_model)
59
60
def test_weight_hub_files_offline_ok(prefetched, offline):
    """A prefetched model's weight files must resolve from the local cache while offline.

    The ``prefetched`` fixture downloads only ``*.safetensors`` files, and
    exactly one file named ``model.safetensors`` is expected back.
    """
    filenames = weight_hub_files(prefetched)
    # With a single expected entry, there is no directory consistency to
    # compare across files. Checking the basenames pins both the count and
    # the name, and a mismatch shows the full list in the failure message.
    assert [os.path.basename(f) for f in filenames] == ["model.safetensors"]
73
74
def test_weight_hub_files():
    """``weight_hub_files`` lists exactly ``model.safetensors`` for bloom-560m."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected
78
79
def test_weight_hub_files_llm():
    """The sharded bloom checkpoint is listed as 72 safetensors shards, in order."""
    shard_names = [
        f"model_{shard_idx:05d}-of-00072.safetensors" for shard_idx in range(1, 73)
    ]
    assert weight_hub_files("bigscience/bloom") == shard_names
83
84
def test_weight_hub_files_empty():
    """An extension that matches no files raises ``EntryNotFoundError``."""
    bogus_extension = ".errors"
    with pytest.raises(EntryNotFoundError):
        weight_hub_files("bigscience/bloom", extension=bogus_extension)
88
89
def test_download_weights():
    """Files returned by ``download_weights`` must equal those ``weight_files`` finds."""
    model_id = "bigscience/bloom-560m"
    filenames = weight_hub_files(model_id)
    files = download_weights(filenames, model_id)
    # Reuse model_id instead of repeating the literal, so the download and
    # the local lookup always target the same repository.
    local_files = weight_files(model_id)
    assert files == local_files
96
97
def test_weight_files_revision_error():
    """Asking for a revision that does not exist raises ``RevisionNotFoundError``."""
    bad_revision = "error"
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision=bad_revision)
101
102
def test_weight_files_not_cached_error(fresh_cache):
    """``weight_files`` against an empty cache raises ``LocalEntryNotFoundError``."""
    uncached_model = "bert-base-uncased"
    with pytest.raises(LocalEntryNotFoundError):
        weight_files(uncached_model)
106