openai-python
130 строк · 3.8 Кб
1from __future__ import annotations
2
3import os
4import inspect
5import traceback
6import contextlib
7from typing import Any, TypeVar, Iterator, cast
8from datetime import date, datetime
9from typing_extensions import Literal, get_args, get_origin, assert_type
10
11from openai._types import NoneType
12from openai._utils import (
13is_dict,
14is_list,
15is_list_type,
16is_union_type,
17)
18from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
19from openai._models import BaseModel
20
# Type variable meaning "some concrete `BaseModel` subclass"; `assert_matches_model`
# uses it to tie the type of its `value` argument to its `model` argument.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
22
23
def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
    """Assert that every field declared on `model` holds a value matching its annotation.

    Each field is read from `value` with `getattr` and checked with
    `assert_matches_type`, using the field name appended to `path` as the breadcrumb.

    Args:
        model: The model class whose declared fields are checked.
        value: An instance of `model` to validate.
        path: Breadcrumb of field names leading to `value`, used for debugging output.

    Returns:
        Always `True`, so the call can itself be wrapped in an `assert`.

    Raises:
        AssertionError: If any field value does not match its declared type.
    """
    for field_name, model_field in get_model_fields(model).items():
        actual = getattr(value, field_name)

        # Pydantic v1 structured nullability differently, so `None` acceptance is
        # read from the field's `allow_none` attribute there; under v2 it is always
        # False here and `None` must be permitted by the annotation itself.
        # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
        nullable = False if PYDANTIC_V2 else getattr(model_field, "allow_none", False)

        assert_matches_type(
            field_outer_type(model_field),
            actual,
            path=[*path, field_name],
            allow_none=nullable,
        )
    return True
42
43
# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
def assert_matches_type(
    type_: Any,
    value: object,
    *,
    path: list[str],
    allow_none: bool = False,
) -> None:
    """Assert that `value` conforms, at runtime, to the annotation `type_`.

    Supported annotations: `None`/`NoneType`, `str`, `int`, `bool`, `float`,
    `bytes`, `datetime`, `date`, `object`/`Any` (anything is accepted),
    `Literal[...]`, list types (every item is checked), `dict[K, V]` (every key
    and value is checked), unions (variants are tried in order), `BaseModel`
    subclasses (checked field by field via `assert_matches_model`) and the
    `HttpxBinaryResponseContent` class (matched by class name). Scalar checks use
    `isinstance`, so subclasses match (e.g. a `bool` satisfies `int`).

    Args:
        type_: The annotation to check against.
        value: The runtime value to check.
        path: Breadcrumb of field names leading to `value`; only used to make
            failures easier to debug.
        allow_none: If true, `None` is accepted regardless of `type_`.

    Raises:
        AssertionError: If `value` does not match `type_`, or `type_` is not a
            supported annotation.
    """
    if allow_none and value is None:
        return

    if type_ is None or type_ is NoneType:
        assert value is None
        return

    origin = get_origin(type_) or type_

    if is_list_type(type_):
        _assert_list_type(type_, value, path=path)
        return

    if origin == str:
        assert isinstance(value, str)
    elif origin == int:
        assert isinstance(value, int)
    elif origin == bool:
        assert isinstance(value, bool)
    elif origin == float:
        assert isinstance(value, float)
    elif origin == bytes:
        assert isinstance(value, bytes)
    elif origin == datetime:
        assert isinstance(value, datetime)
    elif origin == date:
        assert isinstance(value, date)
    elif origin == object or type_ is Any:
        # nothing to do here, the expected type is unknown
        pass
    elif origin == Literal:
        assert value in get_args(type_)
    elif origin == dict:
        assert is_dict(value)

        args = get_args(type_)
        # A bare `dict` annotation carries no key/value types: nothing more to check.
        if args:
            key_type = args[0]
            items_type = args[1]

            for key, item in value.items():
                assert_matches_type(key_type, key, path=[*path, "<dict key>"])
                assert_matches_type(items_type, item, path=[*path, "<dict item>"])
    elif is_union_type(type_):
        for i, variant in enumerate(get_args(type_)):
            try:
                assert_matches_type(variant, value, path=[*path, f"variant {i}"])
                return
            except AssertionError:
                # Show why this variant was rejected so a final
                # "Did not match any variants" failure can be diagnosed.
                traceback.print_exc()

        raise AssertionError("Did not match any variants")
    # `issubclass` raises TypeError for non-class origins (e.g. `Annotated`), so guard
    # with `isclass` to fall through to the explicit "Unhandled field type" failure.
    elif inspect.isclass(origin) and issubclass(origin, BaseModel):
        assert isinstance(value, type_)
        assert assert_matches_model(type_, cast(Any, value), path=path)
    elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent":
        assert value.__class__.__name__ == "HttpxBinaryResponseContent"
    else:
        # Raise explicitly: a bare `assert` would be stripped under `python -O`.
        raise AssertionError(f"Unhandled field type: {type_}")


def _assert_list_type(type_: type[object], value: object, *, path: list[str] | None = None) -> None:
    """Assert that `value` is a list whose items all match the element type of `type_`.

    Args:
        type_: A list annotation such as `List[int]`; for a bare `list` only the
            container type of `value` is checked.
        value: The runtime value to check.
        path: Breadcrumb of field names leading to `value`, forwarded for debugging.

    Raises:
        AssertionError: If `value` is not a list or any item does not match.
    """
    assert is_list(value)

    args = get_args(type_)
    if not args:
        return

    inner_type = args[0]
    item_path = [*(path or []), "<list item>"]
    for entry in value:
        # Recurse into the real checker: `typing_extensions.assert_type` is a
        # static-typing helper that performs no runtime check at all.
        assert_matches_type(inner_type, entry, path=item_path)
118
119
@contextlib.contextmanager
def update_env(**new_env: str) -> Iterator[None]:
    """Temporarily set environment variables for the duration of a `with` block.

    The whole of `os.environ` is snapshotted on entry and restored on exit, even
    if the block raises. Any variable added, changed or removed inside the block
    (not only those passed in `new_env`) is therefore reverted.

    Args:
        **new_env: Variable names mapped to the string values to assign.
    """
    snapshot = dict(os.environ)
    try:
        for name, setting in new_env.items():
            os.environ[name] = setting
        yield
    finally:
        # Clear before restoring so variables introduced inside the block are dropped too.
        os.environ.clear()
        os.environ.update(snapshot)
131