openai-python

Форк
0
/
test_transform.py 
299 строк · 9.7 Кб
1
from __future__ import annotations
2

3
from typing import Any, List, Union, Iterable, Optional, cast
4
from datetime import date, datetime
5
from typing_extensions import Required, Annotated, TypedDict
6

7
import pytest
8

9
from openai._utils import PropertyInfo, transform, parse_datetime
10
from openai._compat import PYDANTIC_V2
11
from openai._models import BaseModel
12

13

14
class Foo1(TypedDict):
    """Flat schema: one field that is sent under a camelCase alias."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]


def test_top_level_alias() -> None:
    """A top-level key is renamed to the alias declared in its ``PropertyInfo``."""
    result = transform({"foo_bar": "hello"}, expected_type=Foo1)
    assert result == {"fooBar": "hello"}
20

21

22
class Baz2(TypedDict):
    """Innermost schema of the nested example; ``my_baz`` is sent as ``myBaz``."""

    my_baz: Annotated[str, PropertyInfo(alias="myBaz")]


class Bar2(TypedDict):
    """Middle schema: an aliased scalar plus an aliased nested ``Baz2``."""

    this_thing: Annotated[int, PropertyInfo(alias="this__thing")]
    baz: Annotated[Baz2, PropertyInfo(alias="Baz")]


class Foo2(TypedDict):
    """Outer schema whose single (un-aliased) field holds a ``Bar2``."""

    bar: Bar2


def test_recursive_typeddict() -> None:
    """Aliases are applied at every nesting level of TypedDict fields."""
    shallow = transform({"bar": {"this_thing": 1}}, Foo2)
    assert shallow == {"bar": {"this__thing": 1}}

    deep = transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2)
    assert deep == {"bar": {"Baz": {"myBaz": "foo"}}}
38

39

40
class Bar3(TypedDict):
    """List-element schema; ``my_field`` is sent as ``myField``."""

    my_field: Annotated[str, PropertyInfo(alias="myField")]


class Foo3(TypedDict):
    """Container schema holding a list of ``Bar3`` entries."""

    things: List[Bar3]


def test_list_of_typeddict() -> None:
    """Every element of a ``List[TypedDict]`` field has its aliases applied."""
    items = [{"my_field": "foo"}, {"my_field": "foo2"}]
    result = transform({"things": items}, expected_type=Foo3)
    assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]}
51

52

53
class Bar4(TypedDict):
    """First union member; ``foo_bar`` is sent as ``fooBar``."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]


class Baz4(TypedDict):
    """Second union member; ``foo_baz`` is sent as ``fooBaz``."""

    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]


class Foo4(TypedDict):
    """Schema whose field may be a ``Bar4`` or a ``Baz4``."""

    foo: Union[Bar4, Baz4]


def test_union_of_typeddict() -> None:
    """Aliases from the members of a ``Union`` of TypedDicts are applied.

    The last case shows keys belonging to different members being renamed
    within the same value.
    """
    cases = [
        ({"foo": {"foo_bar": "bar"}}, {"foo": {"fooBar": "bar"}}),
        ({"foo": {"foo_baz": "baz"}}, {"foo": {"fooBaz": "baz"}}),
        (
            {"foo": {"foo_baz": "baz", "foo_bar": "bar"}},
            {"foo": {"fooBaz": "baz", "fooBar": "bar"}},
        ),
    ]
    for given, expected in cases:
        assert transform(given, Foo4) == expected
69

70

71
class Foo5(TypedDict):
72
    foo: Annotated[Union[Bar4, List[Baz4]], PropertyInfo(alias="FOO")]
73

74

75
class Bar5(TypedDict):
76
    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
77

78

79
class Baz5(TypedDict):
80
    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
81

82

83
def test_union_of_list() -> None:
84
    assert transform({"foo": {"foo_bar": "bar"}}, Foo5) == {"FOO": {"fooBar": "bar"}}
85
    assert transform(
86
        {
87
            "foo": [
88
                {"foo_baz": "baz"},
89
                {"foo_baz": "baz"},
90
            ]
91
        },
92
        Foo5,
93
    ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]}
94

95

96
class Foo6(TypedDict):
    """Schema declaring only ``bar`` (sent as ``Bar``)."""

    bar: Annotated[str, PropertyInfo(alias="Bar")]


def test_includes_unknown_keys() -> None:
    """Keys the TypedDict does not declare are passed through untouched."""
    expected = {
        "Bar": "bar",
        "baz_": {"FOO": 1},
    }
    assert transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6) == expected
105

106

107
class Bar7(TypedDict):
    """Element schema with a single un-aliased ``foo`` string."""

    foo: str


class Foo7(TypedDict):
    """Schema whose fields expect structured values (a list, a TypedDict)."""

    bar: Annotated[List[Bar7], PropertyInfo(alias="bAr")]
    foo: Bar7


def test_ignores_invalid_input() -> None:
    """Values that do not match their annotated shape are passed through as-is.

    The key alias is still applied (``bar`` -> ``bAr``); only the value is left alone.
    """
    for given, expected in (
        ({"bar": "<foo>"}, {"bAr": "<foo>"}),
        ({"foo": "<foo>"}, {"foo": "<foo>"}),
    ):
        assert transform(given, Foo7) == expected
119

120

121
class DatetimeDict(TypedDict, total=False):
    """Schema exercising ``iso8601`` formatting across several annotation shapes.

    The class is non-total; only ``required`` and ``list_`` are marked ``Required``.
    """

    # Plain datetime field.
    foo: Annotated[datetime, PropertyInfo(format="iso8601")]
    # Optional datetime: ``None`` is an accepted value.
    bar: Annotated[Optional[datetime], PropertyInfo(format="iso8601")]
    # Required despite ``total=False``, yet still allowed to be ``None``.
    required: Required[Annotated[Optional[datetime], PropertyInfo(format="iso8601")]]
    # Optional list of datetimes; the format is applied to each element.
    list_: Required[Annotated[Optional[List[datetime]], PropertyInfo(format="iso8601")]]
    # Union with a non-datetime member.
    union: Annotated[Union[int, datetime], PropertyInfo(format="iso8601")]


class DateDict(TypedDict, total=False):
    """Non-total schema with a single ``date`` field formatted as ISO 8601."""

    foo: Annotated[date, PropertyInfo(format="iso8601")]
135

136

137
def test_iso8601_format() -> None:
    """``format="iso8601"`` renders aware/naive datetimes and dates as ISO strings."""
    aware = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    assert transform({"foo": aware}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    # Dropping the tzinfo drops the UTC offset from the rendered string.
    naive = aware.replace(tzinfo=None)
    assert transform({"foo": naive}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692"}  # type: ignore[comparison-overlap]

    # A ``None`` date value comes back as ``None``.
    assert transform({"foo": None}, DateDict) == {"foo": None}  # type: ignore[comparison-overlap]
    assert transform({"foo": date.fromisoformat("2023-02-23")}, DateDict) == {"foo": "2023-02-23"}  # type: ignore[comparison-overlap]
146

147

148
def test_optional_iso8601_format() -> None:
    """An ``Optional[datetime]`` field formats datetimes and keeps ``None`` as ``None``."""
    moment = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    formatted = transform({"bar": moment}, DatetimeDict)
    assert formatted == {"bar": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    missing = transform({"bar": None}, DatetimeDict)
    assert missing == {"bar": None}
153

154

155
def test_required_iso8601_format() -> None:
    """A ``Required[Optional[datetime]]`` field formats datetimes and keeps ``None``."""
    moment = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    formatted = transform({"required": moment}, DatetimeDict)
    assert formatted == {"required": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    missing = transform({"required": None}, DatetimeDict)
    assert missing == {"required": None}
160

161

162
def test_union_datetime() -> None:
    """In a ``Union[int, datetime]`` field, datetimes are formatted and a str passes through."""
    moment = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    assert transform({"union": moment}, DatetimeDict) == {"union": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    # "foo" matches neither union member and is returned unchanged.
    assert transform({"union": "foo"}, DatetimeDict) == {"union": "foo"}
169

170

171
def test_nested_list_iso6801_format() -> None:
    """Each datetime in a list field is rendered as ISO 8601.

    ("6801" in the test name is a typo for ISO 8601.) The second value is
    parsed from a ``Z`` suffix and is rendered with an explicit ``+00:00`` offset.
    """
    first = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    second = parse_datetime("2022-01-15T06:34:23Z")
    expected = ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"]
    assert transform({"list_": [first, second]}, DatetimeDict) == {"list_": expected}  # type: ignore[comparison-overlap]
177

178

179
def test_datetime_custom_format() -> None:
    """``format="custom"`` renders with ``format_template``; ``%H`` yields the zero-padded hour."""
    moment = parse_datetime("2022-01-15T06:34:23Z")

    hour = transform(moment, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")])
    assert hour == "06"  # type: ignore[comparison-overlap]
184

185

186
class DateDictWithRequiredAlias(TypedDict, total=False):
    """Non-total schema whose one required date field is both aliased and ISO-formatted."""

    required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]]


def test_datetime_with_alias() -> None:
    """The alias and the ``iso8601`` format are applied together to the same key."""
    assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None}  # type: ignore[comparison-overlap]

    feb_23 = date.fromisoformat("2023-02-23")
    assert transform({"required_prop": feb_23}, DateDictWithRequiredAlias) == {"prop": "2023-02-23"}  # type: ignore[comparison-overlap]
195

196

197
class MyModel(BaseModel):
    """Minimal model with one string field, shared by the pydantic tests below."""

    foo: str
199

200

201
def test_pydantic_model_to_dictionary() -> None:
    """A model serializes to a dict whether it was validated or built via ``construct``."""
    for model in (MyModel(foo="hi!"), MyModel.construct(foo="hi!")):
        assert transform(model, Any) == {"foo": "hi!"}
204

205

206
def test_pydantic_empty_model() -> None:
    """A model constructed with no fields set serializes to an empty dict."""
    empty = MyModel.construct()
    assert transform(empty, Any) == {}
208

209

210
def test_pydantic_unknown_field() -> None:
    """Undeclared attributes passed to ``construct`` are included in the output."""
    model = MyModel.construct(my_untyped_field=True)
    expected = {"my_untyped_field": True}
    assert transform(model, Any) == expected
212

213

214
def _transform_warns_on_v2(model: object) -> object:
    """Return ``transform(model, Any)``, requiring a ``UserWarning`` under Pydantic v2.

    Under Pydantic v1 the call is made without asserting any warning.
    """
    if not PYDANTIC_V2:
        return transform(model, Any)
    with pytest.warns(UserWarning):
        return transform(model, Any)


def test_pydantic_mismatched_types() -> None:
    """A field holding the wrong type (bool instead of str) is emitted as-is."""
    params = _transform_warns_on_v2(MyModel.construct(foo=True))
    assert params == {"foo": True}


def test_pydantic_mismatched_object_type() -> None:
    """A field holding a nested model instead of a str is emitted as a nested dict."""
    params = _transform_warns_on_v2(MyModel.construct(foo=MyModel.construct(hello="world")))
    assert params == {"foo": {"hello": "world"}}
232

233

234
class ModelNestedObjects(BaseModel):
    """Model with a single field holding a nested ``MyModel``."""

    nested: MyModel


def test_pydantic_nested_objects() -> None:
    """``construct`` turns the nested dict into a ``MyModel``; transform turns it back into a dict."""
    outer = ModelNestedObjects.construct(nested={"foo": "stainless"})
    assert isinstance(outer.nested, MyModel)

    expected = {"nested": {"foo": "stainless"}}
    assert transform(outer, Any) == expected
242

243

244
class ModelWithDefaultField(BaseModel):
    """Model mixing a required field with two fields that have defaults."""

    foo: str
    with_none_default: Union[str, None] = None
    with_str_default: str = "foo"


def test_pydantic_default_field() -> None:
    """Fields are emitted only when explicitly set, regardless of whether they equal the default."""
    # Nothing passed: the defaults are readable on the model but absent from the output.
    untouched = ModelWithDefaultField.construct()
    assert untouched.with_none_default is None
    assert untouched.with_str_default == "foo"
    assert transform(untouched, Any) == {}

    # The default values passed explicitly: they count as set, so they are emitted.
    explicit_defaults = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo")
    assert explicit_defaults.with_none_default is None
    assert explicit_defaults.with_str_default == "foo"
    assert transform(explicit_defaults, Any) == {"with_none_default": None, "with_str_default": "foo"}

    # Non-default values passed explicitly: emitted as given.
    explicit_values = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz")
    assert explicit_values.with_none_default == "bar"
    assert explicit_values.with_str_default == "baz"
    assert transform(explicit_values, Any) == {"with_none_default": "bar", "with_str_default": "baz"}
268

269

270
class Bar8(TypedDict):
    """Single-object union member; ``foo_bar`` is sent as ``fooBar``."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]


class Baz8(TypedDict):
    """Iterable-element union member; ``foo_baz`` is sent as ``fooBaz``."""

    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]


class TypedDictIterableUnion(TypedDict):
    """Schema whose field (sent as ``FOO``) is a ``Bar8`` or an iterable of ``Baz8``."""

    foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")]


def test_iterable_of_dictionaries() -> None:
    """An ``Iterable[TypedDict]`` value may be a list, tuple or generator; a list is emitted."""
    assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]}
    # The tuple input is emitted as a list.
    assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]}

    def greetings() -> Iterable[Baz8]:
        yield {"foo_baz": "hello"}
        yield {"foo_baz": "world"}

    assert transform({"foo": greetings()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}
291

292

293
class TypedDictIterableUnionStr(TypedDict):
    """Schema whose field (sent as ``FOO``) is a plain ``str`` or an iterable of ``Baz8``."""

    foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]


def test_iterable_union_str() -> None:
    """A ``str`` is not split into an iterable of TypedDicts, but a real iterator is transformed."""
    assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"}

    items = iter([{"foo_baz": "bar"}])
    assert cast(Any, transform(items, Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]
300

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.