openai-python

Форк
0
/
utils.py 
130 строк · 3.8 Кб
1
from __future__ import annotations
2

3
import os
4
import inspect
5
import traceback
6
import contextlib
7
from typing import Any, TypeVar, Iterator, cast
8
from datetime import date, datetime
9
from typing_extensions import Literal, get_args, get_origin, assert_type
10

11
from openai._types import NoneType
12
from openai._utils import (
13
    is_dict,
14
    is_list,
15
    is_list_type,
16
    is_union_type,
17
)
18
from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
19
from openai._models import BaseModel
20

21
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
22

23

24
def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
25
    for name, field in get_model_fields(model).items():
26
        field_value = getattr(value, name)
27
        if PYDANTIC_V2:
28
            allow_none = False
29
        else:
30
            # in v1 nullability was structured differently
31
            # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
32
            allow_none = getattr(field, "allow_none", False)
33

34
        assert_matches_type(
35
            field_outer_type(field),
36
            field_value,
37
            path=[*path, name],
38
            allow_none=allow_none,
39
        )
40

41
    return True
42

43

44
# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
45
def assert_matches_type(
46
    type_: Any,
47
    value: object,
48
    *,
49
    path: list[str],
50
    allow_none: bool = False,
51
) -> None:
52
    if allow_none and value is None:
53
        return
54

55
    if type_ is None or type_ is NoneType:
56
        assert value is None
57
        return
58

59
    origin = get_origin(type_) or type_
60

61
    if is_list_type(type_):
62
        return _assert_list_type(type_, value)
63

64
    if origin == str:
65
        assert isinstance(value, str)
66
    elif origin == int:
67
        assert isinstance(value, int)
68
    elif origin == bool:
69
        assert isinstance(value, bool)
70
    elif origin == float:
71
        assert isinstance(value, float)
72
    elif origin == bytes:
73
        assert isinstance(value, bytes)
74
    elif origin == datetime:
75
        assert isinstance(value, datetime)
76
    elif origin == date:
77
        assert isinstance(value, date)
78
    elif origin == object:
79
        # nothing to do here, the expected type is unknown
80
        pass
81
    elif origin == Literal:
82
        assert value in get_args(type_)
83
    elif origin == dict:
84
        assert is_dict(value)
85

86
        args = get_args(type_)
87
        key_type = args[0]
88
        items_type = args[1]
89

90
        for key, item in value.items():
91
            assert_matches_type(key_type, key, path=[*path, "<dict key>"])
92
            assert_matches_type(items_type, item, path=[*path, "<dict item>"])
93
    elif is_union_type(type_):
94
        for i, variant in enumerate(get_args(type_)):
95
            try:
96
                assert_matches_type(variant, value, path=[*path, f"variant {i}"])
97
                return
98
            except AssertionError:
99
                traceback.print_exc()
100
                continue
101

102
        raise AssertionError("Did not match any variants")
103
    elif issubclass(origin, BaseModel):
104
        assert isinstance(value, type_)
105
        assert assert_matches_model(type_, cast(Any, value), path=path)
106
    elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent":
107
        assert value.__class__.__name__ == "HttpxBinaryResponseContent"
108
    else:
109
        assert None, f"Unhandled field type: {type_}"
110

111

112
def _assert_list_type(type_: type[object], value: object) -> None:
113
    assert is_list(value)
114

115
    inner_type = get_args(type_)[0]
116
    for entry in value:
117
        assert_type(inner_type, entry)  # type: ignore
118

119

120
@contextlib.contextmanager
121
def update_env(**new_env: str) -> Iterator[None]:
122
    old = os.environ.copy()
123

124
    try:
125
        os.environ.update(new_env)
126

127
        yield None
128
    finally:
129
        os.environ.clear()
130
        os.environ.update(old)
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.