openai-python
159 lines · 4.5 KB
import re
import json
from typing import List

import httpx
import pytest
import pydantic

from openai import OpenAI, BaseModel, AsyncOpenAI
from openai._response import (
    APIResponse,
    BaseAPIResponse,
    AsyncAPIResponse,
    BinaryAPIResponse,
    AsyncBinaryAPIResponse,
    extract_response_type,
)
from openai._streaming import Stream, AsyncStream
from openai._base_client import FinalRequestOptions
20
class ConcreteBaseAPIResponse(BaseAPIResponse[bytes]):
    """Concrete subclass of the generic ``BaseAPIResponse`` fixed to ``bytes``.

    Lets ``extract_response_type`` be checked against a subclass whose generic
    origin is ``BaseAPIResponse`` itself, not one of its derived classes.
    """
24
class ConcreteAPIResponse(APIResponse[List[str]]):
    """``APIResponse`` subclass whose type argument is pinned to ``List[str]``."""
28
class ConcreteAsyncAPIResponse(AsyncAPIResponse[httpx.Response]):
    """Concrete subclass of ``AsyncAPIResponse`` fixed to ``httpx.Response``.

    Subclassing the async response class (as the name implies) makes
    ``extract_response_type`` resolve the type argument through the
    ``AsyncAPIResponse`` base rather than ``APIResponse``.
    """
32
def test_extract_response_type_direct_classes() -> None:
    """Each generic response class, parametrised directly, reports its type argument."""
    parametrised_aliases = (
        BaseAPIResponse[str],
        APIResponse[str],
        AsyncAPIResponse[str],
    )
    for alias in parametrised_aliases:
        assert extract_response_type(alias) == str
38
def test_extract_response_type_direct_class_missing_type_arg() -> None:
    """An unparametrised generic response class has no type argument and must raise."""
    # `match` is applied with `re.search`; escape the literal message so the
    # `.` characters in the module path only match real dots.
    with pytest.raises(
        RuntimeError,
        match=re.escape(
            "Expected type <class 'openai._response.AsyncAPIResponse'> to have a type argument at index 0 but it did not"
        ),
    ):
        extract_response_type(AsyncAPIResponse)
46
def test_extract_response_type_concrete_subclasses() -> None:
    """Subclasses that bind the type argument expose it through ``extract_response_type``."""
    expectations = [
        (ConcreteBaseAPIResponse, bytes),
        (ConcreteAPIResponse, List[str]),
        (ConcreteAsyncAPIResponse, httpx.Response),
    ]
    for response_cls, expected_type in expectations:
        assert extract_response_type(response_cls) == expected_type
52
def test_extract_response_type_binary_response() -> None:
    """Both sync and async binary response classes resolve to ``bytes``."""
    for binary_cls in (BinaryAPIResponse, AsyncBinaryAPIResponse):
        assert extract_response_type(binary_cls) == bytes
57
class PydanticModel(pydantic.BaseModel):
    """Plain pydantic model that does not inherit from ``openai.BaseModel``."""
61
def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
    """Parsing into a plain pydantic model (not ``openai.BaseModel``) raises ``TypeError``."""
    response = APIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    # `match` is a regex searched against the message; escape the literal text
    # so the `.` characters in "e.g." only match real dots.
    with pytest.raises(
        TypeError,
        match=re.escape("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`"),
    ):
        response.parse(to=PydanticModel)
78
@pytest.mark.asyncio
async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None:
    """Async variant: parsing into a plain pydantic model raises ``TypeError``."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    # `match` is a regex searched against the message; escape the literal text
    # so the `.` characters in "e.g." only match real dots.
    with pytest.raises(
        TypeError,
        match=re.escape("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`"),
    ):
        await response.parse(to=PydanticModel)
96
def test_response_parse_custom_stream(client: OpenAI) -> None:
    """``parse(to=Stream[int])`` returns a stream whose chunk type is ``int``."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed_stream = response.parse(to=Stream[int])
    assert parsed_stream._cast_to == int
110
@pytest.mark.asyncio
async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None:
    """``await parse(to=AsyncStream[int])`` returns a stream whose chunk type is ``int``.

    The async response is parsed into ``AsyncStream`` rather than the sync
    ``Stream``, so the async stream class is what this test exercises.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=b"foo"),
        client=async_client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    stream = await response.parse(to=AsyncStream[int])
    assert stream._cast_to == int
125
class CustomModel(BaseModel):
    """Minimal ``openai.BaseModel`` subclass used as a custom parse target."""

    foo: str
    bar: int
130
def test_response_parse_custom_model(client: OpenAI) -> None:
    """A JSON body is parsed into a user-defined ``BaseModel`` subclass."""
    payload = {"foo": "hello!", "bar": 2}
    response = APIResponse(
        raw=httpx.Response(200, content=json.dumps(payload)),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    parsed = response.parse(to=CustomModel)
    assert (parsed.foo, parsed.bar) == ("hello!", 2)
145
@pytest.mark.asyncio
async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None:
    """Async variant: a JSON body is parsed into a user-defined ``BaseModel`` subclass."""
    payload = {"foo": "hello!", "bar": 2}
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps(payload)),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    parsed = await response.parse(to=CustomModel)
    assert (parsed.foo, parsed.bar) == ("hello!", 2)