text-generation-inference

Форк
0
105 строк · 3.1 Кб
1
import os
2
import requests
3
import tempfile
4

5
import pytest
6

7
import huggingface_hub.constants
8
from huggingface_hub import hf_api
9

10
import text_generation_server.utils.hub
11
from text_generation_server.utils.hub import (
12
    weight_hub_files,
13
    download_weights,
14
    weight_files,
15
    EntryNotFoundError,
16
    LocalEntryNotFoundError,
17
    RevisionNotFoundError,
18
)
19

20

21
@pytest.fixture()
22
def offline():
23
    current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
24
    text_generation_server.utils.hub.HF_HUB_OFFLINE = True
25
    yield "offline"
26
    text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
27

28

29
@pytest.fixture()
30
def fresh_cache():
31
    with tempfile.TemporaryDirectory() as d:
32
        current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
33
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
34
        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
35
        os.environ["HUGGINGFACE_HUB_CACHE"] = d
36
        yield
37
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
38
        os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
39
        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
40

41

42
@pytest.fixture()
43
def prefetched():
44
    model_id = "bert-base-uncased"
45
    huggingface_hub.snapshot_download(
46
        repo_id=model_id,
47
        revision="main",
48
        local_files_only=False,
49
        repo_type="model",
50
        allow_patterns=["*.safetensors"],
51
    )
52
    yield model_id
53

54

55
def test_weight_hub_files_offline_error(offline, fresh_cache):
56
    # If the model is not prefetched then it will raise an error
57
    with pytest.raises(EntryNotFoundError):
58
        weight_hub_files("gpt2")
59

60

61
def test_weight_hub_files_offline_ok(prefetched, offline):
62
    # If the model is prefetched then we should be able to get the weight files from local cache
63
    filenames = weight_hub_files(prefetched)
64
    root = None
65
    assert len(filenames) == 1
66
    for f in filenames:
67
        curroot, filename = os.path.split(f)
68
        if root is None:
69
            root = curroot
70
        else:
71
            assert root == curroot
72
        assert filename == "model.safetensors"
73

74

75
def test_weight_hub_files():
76
    filenames = weight_hub_files("bigscience/bloom-560m")
77
    assert filenames == ["model.safetensors"]
78

79

80
def test_weight_hub_files_llm():
81
    filenames = weight_hub_files("bigscience/bloom")
82
    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
83

84

85
def test_weight_hub_files_empty():
86
    with pytest.raises(EntryNotFoundError):
87
        weight_hub_files("bigscience/bloom", extension=".errors")
88

89

90
def test_download_weights():
91
    model_id = "bigscience/bloom-560m"
92
    filenames = weight_hub_files(model_id)
93
    files = download_weights(filenames, model_id)
94
    local_files = weight_files("bigscience/bloom-560m")
95
    assert files == local_files
96

97

98
def test_weight_files_revision_error():
99
    with pytest.raises(RevisionNotFoundError):
100
        weight_files("bigscience/bloom-560m", revision="error")
101

102

103
def test_weight_files_not_cached_error(fresh_cache):
104
    with pytest.raises(LocalEntryNotFoundError):
105
        weight_files("bert-base-uncased")
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.