instructor
103 строки · 2.8 Кб
1from typing import Any, Callable, cast
2import pytest
3import instructor
4
5from openai import OpenAI
6from pydantic import BaseModel
7
8from instructor.distil import (
9Instructions,
10format_function,
11get_signature_from_fn,
12is_return_type_base_model_or_instance,
13)
14
# OpenAI client passed through ``instructor.patch``; used by ``fn`` below.
client = instructor.patch(OpenAI())

# Shared ``Instructions`` instance whose ``distil`` decorator the tests exercise.
instructions = Instructions(name="test_distil")
20
21
class SimpleModel(BaseModel):  # type: ignore[misc]
    """Minimal pydantic model with a single integer field, used by the tests."""

    # The only payload carried by the model.
    data: int
24
25
def test_must_have_hint() -> None:
    """Expect ``AssertionError`` when distilling a function with no return hint."""

    def test_func(x: int):  # type: ignore[no-untyped-def]
        return SimpleModel(data=x)

    # Apply the decorator by hand so that only the wrapping step runs inside
    # the ``pytest.raises`` context.
    with pytest.raises(AssertionError):
        instructions.distil(test_func)
32
33
def test_must_be_base_model() -> None:
    """Expect ``AssertionError`` when the return hint is ``int``, not a model."""

    # The annotation deliberately says ``int`` even though a ``SimpleModel``
    # is returned: the test targets the declared return type.
    def test_func(x: int) -> int:
        return SimpleModel(data=x)

    # Apply the decorator by hand so that only the wrapping step runs inside
    # the ``pytest.raises`` context.
    with pytest.raises(AssertionError):
        instructions.distil(test_func)
40
41
def test_is_return_type_base_model_or_instance() -> None:
    """Check the helper is truthy for a model-typed function and falsy for ``int``."""

    def valid_function() -> SimpleModel:
        return SimpleModel(data=1)

    def invalid_function() -> int:
        return 1

    model_hint_ok = is_return_type_base_model_or_instance(valid_function)
    assert model_hint_ok

    int_hint_ok = is_return_type_base_model_or_instance(invalid_function)
    assert not int_hint_ok
51
52
def test_get_signature_from_fn() -> None:
    """Check the rendered signature contains the ``def`` line and the docstring."""

    def test_function(a: int, b: str) -> float:  # type: ignore[empty-body]
        """Sample docstring"""
        pass

    rendered = get_signature_from_fn(test_function)

    # Both the signature text and the docstring text must appear in the output.
    required_fragments = (
        "def test_function(a: int, b: str) -> float",
        "Sample docstring",
    )
    for fragment in required_fragments:
        assert fragment in rendered
62
63
def test_format_function() -> None:
    """Check ``format_function`` output contains the header, docstring and body."""

    # The assertions below look for this function's exact text, so keep its
    # definition verbatim.
    def sample_function(x: int) -> SimpleModel:
        """This is a docstring."""
        return SimpleModel(data=x)

    rendered = format_function(sample_function)

    required_fragments = (
        "def sample_function(x: int) -> SimpleModel:",
        '"""This is a docstring."""',
        "return SimpleModel(data=x)",
    )
    for fragment in required_fragments:
        assert fragment in rendered
73
74
def test_distil_decorator_without_arguments() -> None:
    """Check a function decorated with bare ``@instructions.distil`` still returns its model."""

    @instructions.distil
    def test_func(x: int) -> SimpleModel:
        return SimpleModel(data=x)

    # The cast only informs the type checker; it has no runtime effect.
    typed_func = cast(Callable[[int], SimpleModel], test_func)
    outcome: SimpleModel = typed_func(42)

    assert outcome.data == 42
83
84
def test_distil_decorator_with_name_argument() -> None:
    """Check a function decorated with ``distil(name=...)`` still returns its model."""

    @instructions.distil(name="custom_name")
    def another_test_func(x: int) -> SimpleModel:
        return SimpleModel(data=x)

    # The cast only informs the type checker; it has no runtime effect.
    typed_func = cast(Callable[[int], SimpleModel], another_test_func)
    outcome: SimpleModel = typed_func(55)

    assert outcome.data == 55
93
94
95# Mock track function for decorator tests
def mock_track(*args: Any, **kwargs: Any) -> None:
    """Accept any positional and keyword arguments and do nothing.

    The variadic parameters are annotated with ``Any`` because an annotation
    on ``*args`` / ``**kwargs`` describes each individual argument, not the
    collected tuple or dict.

    Args:
        *args: Ignored positional arguments.
        **kwargs: Ignored keyword arguments.

    Returns:
        None.
    """
    return None
98
99
def fn(a: int, b: int) -> int:
    """Send a chat-completion request through the module-level ``client``.

    Neither ``a`` nor ``b`` is used in the body.

    NOTE(review): the return annotation is ``int`` while the request passes
    ``response_model=SimpleModel`` -- confirm the mismatch is intentional.
    """
    completion = client.chat.completions.create(
        model="davinci",
        messages=[],
        response_model=SimpleModel,
    )
    return completion
104