# litellm (fork) - test_openai_endpoints.py (433 lines, 12.4 KB)
# What this tests:
## The proxy's OpenAI-compatible endpoints (/chat/completions, /completions,
## /embeddings, /images/generations): generate a key (or create a user), then
## make requests with it.
import asyncio
from typing import List, Optional, Union

import aiohttp
import openai
import pytest
from openai import AsyncOpenAI, OpenAI


def response_header_check(response):
    """
    Assert that the response headers stay under the 4kb nginx limit.

    ``response.raw_headers`` is iterated as ``(name, value)`` pairs (bytes, as
    returned by aiohttp). Each header is measured the way it appears on the
    wire: name + value, plus 4 bytes for the ": " separator and the trailing
    CRLF. The separators count toward the buffer limit too.

    Raises:
        AssertionError: if the total header size is 4096 bytes or more.
    """
    headers_size = sum(
        len(name) + len(value) + 4  # + len(b": ") + len(b"\r\n")
        for name, value in response.raw_headers
    )
    assert headers_size < 4096, (
        f"Response headers exceed the 4kb limit ({headers_size} bytes)"
    )


async def generate_key(
    session,
    models: Optional[List[str]] = None,
):
    """
    Create a virtual key through the proxy's /key/generate endpoint.

    Args:
        session: aiohttp client session used to send the request.
        models: models the new key may access. Defaults to
            ["gpt-4", "text-embedding-ada-002", "dall-e-2",
            "fake-openai-endpoint-2"]. Pass ["*"] for a wildcard key.

    Returns:
        The decoded JSON body of the response (callers read its "key" field).

    Raises:
        Exception: if the proxy does not answer with HTTP 200.
        AssertionError: if the response headers exceed the 4kb limit.
    """
    if models is None:
        # Build the default per call; a list literal as the default argument
        # would be one object shared (and mutable) across every call.
        models = [
            "gpt-4",
            "text-embedding-ada-002",
            "dall-e-2",
            "fake-openai-endpoint-2",
        ]
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
        "models": models,
        "duration": None,
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(response)  # headers must fit the 4kb limit

        return await response.json()


async def new_user(session):
    """
    Create a user through the proxy's /user/new endpoint.

    The user is scoped to gpt-4, text-embedding-ada-002 and dall-e-2, with
    "duration" sent as None. Returns the decoded JSON reply (callers read its
    "key" field); raises if the status code is not 200.
    """
    payload = {
        "models": ["gpt-4", "text-embedding-ada-002", "dall-e-2"],
        "duration": None,
    }
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }

    async with session.post(
        "http://0.0.0.0:4000/user/new", headers=request_headers, json=payload
    ) as resp:
        body = await resp.text()
        print(body)
        print()

        if resp.status != 200:
            raise Exception(f"Request did not return a 200 status code: {resp.status}")

        # Reject oversized response headers before decoding the body.
        response_header_check(resp)
        return await resp.json()


async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
    """
    POST a fixed two-message conversation to the proxy's /chat/completions.

    Args:
        session: aiohttp client session.
        key: bearer key sent in the Authorization header.
        model: model name, forwarded as-is (the batch test passes a
            comma-separated string such as "gpt-3.5-turbo,fake-openai-endpoint").

    Returns:
        The decoded JSON body.

    Raises:
        Exception: when the status code is not 200; the message embeds the
            code, which the rate-limit test matches on.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    auth_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    endpoint = "http://0.0.0.0:4000/chat/completions"

    async with session.post(
        endpoint, headers=auth_headers, json={"model": model, "messages": conversation}
    ) as resp:
        status_code = resp.status
        text = await resp.text()
        print(text)
        print()

        if status_code == 200:
            response_header_check(resp)
            return await resp.json()

        raise Exception(f"Request did not return a 200 status code: {status_code}")


async def chat_completion_with_headers(session, key, model="gpt-4"):
    """
    Send a chat completion request and return the response headers, not the body.

    Args:
        session: aiohttp client session.
        key: bearer key sent in the Authorization header.
        model: model (or model group) to call.

    Returns:
        dict mapping each response header name to its value, both decoded as
        UTF-8. If a header name repeats, the last occurrence wins.

    Raises:
        Exception: if the status code is not 200.
        AssertionError: if the response headers exceed the 4kb limit.
    """
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(response)  # headers must fit the 4kb limit

        # raw_headers holds (name, value) byte pairs, e.g.
        # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), ...)
        return {
            name.decode("utf-8"): value.decode("utf-8")
            for name, value in response.raw_headers
        }


async def completion(session, key):
    """
    POST a text completion ("Hello!" with gpt-4) to the proxy's /completions.

    The response body is printed before the status check, like the other
    request helpers in this file, so a failing status can be diagnosed.

    Args:
        session: aiohttp client session.
        key: bearer key sent in the Authorization header.

    Returns:
        The decoded JSON body.

    Raises:
        Exception: if the status code is not 200.
        AssertionError: if the response headers exceed the 4kb limit.
    """
    url = "http://0.0.0.0:4000/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    data = {"model": "gpt-4", "prompt": "Hello!"}

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(response)  # headers must fit the 4kb limit

        return await response.json()


async def embeddings(session, key):
    """
    POST ["hello world"] to the proxy's /embeddings using text-embedding-ada-002.

    Args:
        session: aiohttp client session.
        key: bearer key sent in the Authorization header.

    Returns:
        The decoded JSON body, like the other request helpers in this file.

    Raises:
        Exception: if the status code is not 200.
        AssertionError: if the response headers exceed the 4kb limit.
    """
    url = "http://0.0.0.0:4000/embeddings"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    data = {
        "model": "text-embedding-ada-002",
        "input": ["hello world"],
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(response)  # headers must fit the 4kb limit

        return await response.json()


async def image_generation(session, key):
    """
    Ask the proxy's /images/generations endpoint for a dall-e-2 image.

    A non-200 reply whose body mentions "Connection error" is tolerated: the
    helper returns without checking headers, because the upstream OpenAI
    endpoint can answer with a connection error. Any other non-200 reply
    raises. Always returns None.
    """
    request_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "dall-e-2",
        "prompt": "A cute baby sea otter",
    }
    endpoint = "http://0.0.0.0:4000/images/generations"

    async with session.post(endpoint, headers=request_headers, json=payload) as resp:
        code = resp.status
        body = await resp.text()
        print(body)
        print()

        if code == 200:
            response_header_check(resp)
            return None
        if "Connection error" in body:
            return None
        raise Exception(f"Request did not return a 200 status code: {code}")


@pytest.mark.asyncio
async def test_chat_completion():
    """
    Chat completions succeed both with a freshly generated key and with the
    key issued for a newly created user (in that order).
    """
    async with aiohttp.ClientSession() as session:
        for issue_credentials in (generate_key, new_user):
            credentials = await issue_credentials(session=session)
            await chat_completion(session=session, key=credentials["key"])


@pytest.mark.asyncio
async def test_chat_completion_ratelimit():
    """
    Parallel calls to a rate-limited model must be rejected with HTTP 429.

    - call model with rpm 1 ("fake-openai-endpoint-2")
    - make 2 parallel calls
    - make sure at least 1 fails, and that every failure is a 429
    """
    async with aiohttp.ClientSession() as session:
        key = "sk-1234"
        tasks = [
            chat_completion(session=session, key=key, model="fake-openai-endpoint-2")
            for _ in range(2)
        ]
        # return_exceptions=True waits for *both* calls to finish, so no request
        # is still in flight when the session closes at the end of this block.
        results = await asyncio.gather(*tasks, return_exceptions=True)

    errors = [result for result in results if isinstance(result, BaseException)]
    if not errors:
        pytest.fail("Expected at least 1 call to fail")
    for error in errors:
        if "Request did not return a 200 status code: 429" not in str(error):
            pytest.fail(f"Wrong error received - {str(error)}")


@pytest.mark.asyncio
async def test_chat_completion_different_deployments():
    """
    Calls to a model group with 2 deployments are spread across both of them.

    - call model group with 2 deployments ("fake-openai-endpoint-3")
    - make 5 sequential calls
    - expect at least 2 unique deployments, identified by the
      x-litellm-model-id response header
    """
    async with aiohttp.ClientSession() as session:
        key = "sk-1234"
        results = []
        for _ in range(5):
            results.append(
                await chat_completion_with_headers(
                    session=session, key=key, model="fake-openai-endpoint-3"
                )
            )

    print(f"results: {results}")
    model_ids = [result.get("x-litellm-model-id") for result in results]
    # A missing header must fail the test rather than be silently ignored.
    if None in model_ids:
        pytest.fail(f"x-litellm-model-id header missing from a response: {results}")
    # NOTE(review): if the router picks deployments at random, all 5 calls can
    # occasionally land on one deployment; confirm the routing strategy.
    if len(set(model_ids)) < 2:
        pytest.fail("Expected at least 1 shuffled call")


@pytest.mark.asyncio
async def test_chat_completion_streaming():
    """
    [PROD Test] Stream a chat completion with logprobs enabled through the proxy.

    Requests logprobs=True / top_logprobs=2 with stream=True and concatenates
    the streamed content. NOTE(review): the stated goal is to ensure logprobs
    are returned correctly, but no assertion on logprobs is made; confirm
    whether one should be added.
    """
    # The context manager closes the client's HTTP connections when done.
    async with AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") as client:
        response = await client.chat.completions.create(
            model="gpt-3.5-turbo-large",
            messages=[{"role": "user", "content": "Hello!"}],
            logprobs=True,
            top_logprobs=2,
            stream=True,
        )

        parts = []
        async for chunk in response:
            # A chunk may carry no choices (e.g. a usage-only chunk); skip it
            # instead of raising IndexError.
            if not chunk.choices:
                continue
            parts.append(chunk.choices[0].delta.content or "")

    response_str = "".join(parts)
    print(f"response_str: {response_str}")


@pytest.mark.asyncio
async def test_chat_completion_old_key():
    """
    Production test for backwards compatibility: a pre-generated (old) key
    stored in the DB must still be accepted for a chat completion call.
    """
    async with aiohttp.ClientSession() as session:
        key = "sk--W0Ph0uDZLVD7V7LQVrslg"
        try:
            await chat_completion(session=session, key=key)
        except Exception as e:
            # Include the underlying error: a 5xx or an oversized-header
            # failure is not the same thing as a rejected key.
            pytest.fail(f"Invalid api key - {e}")


@pytest.mark.asyncio
async def test_completion():
    """
    - Create key, make a text completion (/completions) call with it
    - Create user, then validate the OpenAI response format by making a
      completion call with the user's key through the official OpenAI client
    """
    async with aiohttp.ClientSession() as session:
        key_gen = await generate_key(session=session)
        key = key_gen["key"]
        await completion(session=session, key=key)
        key_gen = await new_user(session=session)
        key_2 = key_gen["key"]

    ## validate openai format ##
    # This is the synchronous client, so the call blocks the event loop while
    # it runs; the `with` block closes its HTTP connections afterwards.
    with OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000") as client:
        client.completions.create(
            model="gpt-4",
            prompt="Say this is a test",
            max_tokens=7,
            temperature=0,
        )


@pytest.mark.asyncio
async def test_embeddings():
    """
    Embeddings succeed with a freshly generated key, then with the key issued
    for a newly created user.
    """
    async with aiohttp.ClientSession() as session:
        generated_key = (await generate_key(session=session))["key"]
        await embeddings(session=session, key=generated_key)

        user_key = (await new_user(session=session))["key"]
        await embeddings(session=session, key=user_key)


@pytest.mark.asyncio
async def test_image_generation():
    """
    Image generation works with a generated key and with a new user's key.

    - Create key, make an image generation call
    - Create user, make an image generation call
    """
    async with aiohttp.ClientSession() as session:
        for make_credentials in (generate_key, new_user):
            issued = await make_credentials(session=session)
            await image_generation(session=session, key=issued["key"])


@pytest.mark.asyncio
async def test_openai_wildcard_chat_completion():
    """
    A key created with models=["*"] has access to all models.

    proxy_server_config.yaml is expected to have a model = "*" entry. The call
    below uses gpt-3.5-turbo-0125, which the key was not explicitly created for
    and which is not listed in the config.
    """
    async with aiohttp.ClientSession() as session:
        wildcard_key = (await generate_key(session=session, models=["*"]))["key"]
        await chat_completion(
            session=session, key=wildcard_key, model="gpt-3.5-turbo-0125"
        )


@pytest.mark.asyncio
async def test_batch_chat_completions():
    """
    Make one chat completion call naming two models, comma-separated
    ("gpt-3.5-turbo,fake-openai-endpoint"), and expect a list of two
    responses back.
    """
    async with aiohttp.ClientSession() as session:
        # Both models are passed in a single "model" field.
        response = await chat_completion(
            session=session,
            key="sk-1234",
            model="gpt-3.5-turbo,fake-openai-endpoint",
        )

        print(f"response: {response}")

        # Check the type first: a dict with two keys would also have len() == 2.
        assert isinstance(response, list)
        assert len(response) == 2
# --- Hosting-page residue (GitVerse cookie notice), translated from Russian ---
# Cookie usage
# We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.
# By clicking "Accept", you consent to SberTech JSC processing your personal data
# in order to improve our website and the GitVerse Service and to make them more
# convenient to use.
# You can disable cookies yourself in your browser settings.