text-generation-inference
150 строк · 5.3 Кб
1import pytest
2
3from text_generation import Client, AsyncClient
4from text_generation.errors import NotFoundError, ValidationError
5from text_generation.types import FinishReason, InputToken
6
7
def test_generate(flan_t5_xxl_url, hf_headers):
    """Generate a single token synchronously and check the text and details returned."""
    api = Client(flan_t5_xxl_url, hf_headers)
    result = api.generate("test", max_new_tokens=1, decoder_input_details=True)
    details = result.details

    assert result.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None

    # decoder_input_details=True requests prefill info; a lone <pad> token is expected.
    assert len(details.prefill) == 1
    assert details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)

    assert len(details.tokens) == 1
    token = details.tokens[0]
    assert token.id == 3
    assert token.text == " "
    assert not token.special
22
23
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
    """Sample with best_of=2 and check that seeds are reported for every sequence."""
    api = Client(flan_t5_xxl_url, hf_headers)
    result = api.generate(
        "test",
        max_new_tokens=1,
        best_of=2,
        do_sample=True,
        decoder_input_details=True,
    )
    details = result.details

    assert details.seed is not None

    # The test expects exactly one entry in best_of_sequences when best_of=2.
    extra = details.best_of_sequences
    assert extra is not None
    assert len(extra) == 1
    assert extra[0].seed is not None
34
35
def test_generate_not_found(fake_url, hf_headers):
    """Generating against `fake_url` must raise NotFoundError."""
    api = Client(fake_url, hf_headers)
    # Callable form of pytest.raises: invokes api.generate("test") and expects the error.
    pytest.raises(NotFoundError, api.generate, "test")
40
41
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
    """A request for 10,000 new tokens must be rejected with ValidationError."""
    api = Client(flan_t5_xxl_url, hf_headers)
    # NOTE(review): presumably 10_000 exceeds the server's max_new_tokens limit — confirm.
    pytest.raises(ValidationError, api.generate, "test", max_new_tokens=10_000)
46
47
def test_generate_stream(flan_t5_xxl_url, hf_headers):
    """Stream a single token and check that the one streamed response carries the details."""
    client = Client(flan_t5_xxl_url, hf_headers)
    # list() drains the stream directly; an identity comprehension adds nothing.
    responses = list(client.generate_stream("test", max_new_tokens=1))

    assert len(responses) == 1
    response = responses[0]

    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
61
62
def test_generate_stream_not_found(fake_url, hf_headers):
    """Streaming from `fake_url` must raise NotFoundError while the stream is consumed."""
    api = Client(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # Iterate the stream to completion so any request error surfaces here.
        for _ in api.generate_stream("test"):
            pass
67
68
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
    """Streaming with max_new_tokens=10_000 must raise ValidationError."""
    api = Client(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Iterate the stream to completion so any request error surfaces here.
        for _ in api.generate_stream("test", max_new_tokens=10_000):
            pass
73
74
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
    """Generate a single token with AsyncClient and check the text and details returned."""
    api = AsyncClient(flan_t5_xxl_url, hf_headers)
    result = await api.generate("test", max_new_tokens=1, decoder_input_details=True)
    details = result.details

    assert result.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None

    # decoder_input_details=True requests prefill info; a lone <pad> token is expected.
    assert len(details.prefill) == 1
    assert details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)

    assert len(details.tokens) == 1
    token = details.tokens[0]
    assert token.id == 3
    assert token.text == " "
    assert not token.special
92
93
@pytest.mark.asyncio
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
    """Sample with best_of=2 via AsyncClient and check that seeds are reported."""
    api = AsyncClient(flan_t5_xxl_url, hf_headers)
    result = await api.generate(
        "test",
        max_new_tokens=1,
        best_of=2,
        do_sample=True,
        decoder_input_details=True,
    )
    details = result.details

    assert details.seed is not None

    # The test expects exactly one entry in best_of_sequences when best_of=2.
    extra = details.best_of_sequences
    assert extra is not None
    assert len(extra) == 1
    assert extra[0].seed is not None
105
106
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
    """Awaiting a generation against `fake_url` must raise NotFoundError."""
    api = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        await api.generate("test")
112
113
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Awaiting a request for 10,000 new tokens must raise ValidationError."""
    api = AsyncClient(flan_t5_xxl_url, hf_headers)
    # NOTE(review): presumably 10_000 exceeds the server's max_new_tokens limit — confirm.
    with pytest.raises(ValidationError):
        await api.generate("test", max_new_tokens=10_000)
119
120
@pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
    """Stream a single token with AsyncClient and check the one streamed response."""
    api = AsyncClient(flan_t5_xxl_url, hf_headers)
    chunks = [
        chunk
        async for chunk in api.generate_stream("test", max_new_tokens=1)
    ]

    assert len(chunks) == 1
    last = chunks[0]
    details = last.details

    assert last.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None
135
136
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
    """Consuming an async stream from `fake_url` must raise NotFoundError."""
    api = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # Drain the whole async stream so any request error surfaces here.
        _ = [chunk async for chunk in api.generate_stream("test")]
143
144
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Consuming an async stream with max_new_tokens=10_000 must raise ValidationError."""
    api = AsyncClient(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Drain the whole async stream so any request error surfaces here.
        _ = [
            chunk
            async for chunk in api.generate_stream("test", max_new_tokens=10_000)
        ]
151