text-generation-inference

Форк
0
150 строк · 5.3 Кб
1
import pytest
2

3
from text_generation import Client, AsyncClient
4
from text_generation.errors import NotFoundError, ValidationError
5
from text_generation.types import FinishReason, InputToken
6

7

8
def test_generate(flan_t5_xxl_url, hf_headers):
    """Blocking generate of one token reports finish details and prefill tokens."""
    tgi = Client(flan_t5_xxl_url, hf_headers)
    result = tgi.generate("test", max_new_tokens=1, decoder_input_details=True)
    details = result.details

    assert result.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None

    # decoder_input_details=True exposes the prefill: a single <pad> token here.
    assert len(details.prefill) == 1
    assert details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)

    # Exactly one generated token, a non-special space.
    assert len(details.tokens) == 1
    first_token = details.tokens[0]
    assert first_token.id == 3
    assert first_token.text == " "
    assert not first_token.special
22

23

24
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
    """Sampling with best_of=2 reports a seed plus one alternative sequence."""
    tgi = Client(flan_t5_xxl_url, hf_headers)
    result = tgi.generate(
        "test",
        max_new_tokens=1,
        best_of=2,
        do_sample=True,
        decoder_input_details=True,
    )
    details = result.details
    alternatives = details.best_of_sequences

    assert details.seed is not None
    assert alternatives is not None
    assert len(alternatives) == 1
    assert alternatives[0].seed is not None
34

35

36
def test_generate_not_found(fake_url, hf_headers):
    """Calling generate against ``fake_url`` raises NotFoundError."""
    bogus_client = Client(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        bogus_client.generate("test")
40

41

42
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
    """Requesting 10_000 new tokens is rejected with ValidationError."""
    tgi = Client(flan_t5_xxl_url, hf_headers)
    requested_tokens = 10_000
    with pytest.raises(ValidationError):
        tgi.generate("test", max_new_tokens=requested_tokens)
46

47

48
def test_generate_stream(flan_t5_xxl_url, hf_headers):
    """Streaming one token yields exactly one response carrying the details.

    The single streamed response must hold the (empty) generated text and the
    finish details for a length-limited, unseeded generation.
    """
    client = Client(flan_t5_xxl_url, hf_headers)
    # Drain the stream directly; an identity comprehension adds nothing over list().
    responses = list(client.generate_stream("test", max_new_tokens=1))

    assert len(responses) == 1
    response = responses[0]

    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
61

62

63
def test_generate_stream_not_found(fake_url, hf_headers):
    """Consuming a stream from ``fake_url`` raises NotFoundError."""
    bogus_client = Client(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # The call stays inside the raises-block: the error may surface either
        # when generate_stream is called or while iterating its result.
        for _ in bogus_client.generate_stream("test"):
            pass
67

68

69
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
    """Streaming with 10_000 new tokens is rejected with ValidationError."""
    tgi = Client(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Iterate fully so any error raised mid-stream is caught here.
        for _ in tgi.generate_stream("test", max_new_tokens=10_000):
            pass
73

74

75
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
    """Async generate of one token reports finish details and prefill tokens."""
    tgi = AsyncClient(flan_t5_xxl_url, hf_headers)
    result = await tgi.generate("test", max_new_tokens=1, decoder_input_details=True)
    details = result.details

    assert result.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None

    # decoder_input_details=True exposes the prefill: a single <pad> token here.
    assert len(details.prefill) == 1
    assert details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)

    # Exactly one generated token, a non-special space.
    assert len(details.tokens) == 1
    first_token = details.tokens[0]
    assert first_token.id == 3
    assert first_token.text == " "
    assert not first_token.special
92

93

94
@pytest.mark.asyncio
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
    """Async sampling with best_of=2 reports a seed plus one alternative sequence."""
    tgi = AsyncClient(flan_t5_xxl_url, hf_headers)
    result = await tgi.generate(
        "test",
        max_new_tokens=1,
        best_of=2,
        do_sample=True,
        decoder_input_details=True,
    )
    details = result.details
    alternatives = details.best_of_sequences

    assert details.seed is not None
    assert alternatives is not None
    assert len(alternatives) == 1
    assert alternatives[0].seed is not None
105

106

107
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
    """Awaiting generate against ``fake_url`` raises NotFoundError."""
    bogus_client = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        await bogus_client.generate("test")
112

113

114
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Async request for 10_000 new tokens is rejected with ValidationError."""
    tgi = AsyncClient(flan_t5_xxl_url, hf_headers)
    requested_tokens = 10_000
    with pytest.raises(ValidationError):
        await tgi.generate("test", max_new_tokens=requested_tokens)
119

120

121
@pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
    """Async streaming of one token yields exactly one response with details."""
    tgi = AsyncClient(flan_t5_xxl_url, hf_headers)

    collected = []
    async for chunk in tgi.generate_stream("test", max_new_tokens=1):
        collected.append(chunk)

    assert len(collected) == 1
    only_response = collected[0]
    details = only_response.details

    assert only_response.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None
135

136

137
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
    """Consuming an async stream from ``fake_url`` raises NotFoundError."""
    bogus_client = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # Drain the async stream so any error it raises surfaces here.
        _ = [chunk async for chunk in bogus_client.generate_stream("test")]
143

144

145
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Async streaming with 10_000 new tokens is rejected with ValidationError."""
    tgi = AsyncClient(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Drain the async stream so any error it raises surfaces here.
        _ = [
            chunk
            async for chunk in tgi.generate_stream("test", max_new_tokens=10_000)
        ]
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.