instructor

Форк
0
251 строка · 6.0 Кб
1
from enum import Enum
2
from typing import Literal
3

4
import anthropic
5
import pytest
6
from pydantic import BaseModel, field_validator
7

8
import instructor
9
from instructor.retry import InstructorRetryException
10

11
# Module-wide synchronous client: the Anthropic SDK client wrapped by
# instructor in ANTHROPIC_TOOLS mode. Every non-async test in this module
# uses it (the async test constructs its own client).
client = instructor.from_anthropic(
    anthropic.Anthropic(),
    mode=instructor.Mode.ANTHROPIC_TOOLS,
)
14

15

16
def test_simple():
    """Extract a ``User`` whose name validator demands an uppercase name.

    The prompt spells the name as "John". The validator rejects anything not
    uppercase, so the test relies on instructor's retries (``max_retries=2``)
    to end up with "JOHN".
    """

    class User(BaseModel):
        name: str
        age: int

        @field_validator("name")
        def name_is_uppercase(cls, v: str):
            # Raise explicitly instead of using `assert`: assertions are
            # stripped under `python -O`, which would silently disable this
            # validator and the retry behaviour the test depends on.
            if not v.isupper():
                raise ValueError("Name must be uppercase, please fix")
            return v

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=2,
        messages=[
            {
                "role": "user",
                "content": "Extract John is 18 years old.",
            }
        ],
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert resp.name == "JOHN"  # due to validation
    assert resp.age == 18
42

43

44
def test_nested_type():
    """Extract a ``User`` that contains a nested ``Address`` model.

    A single attempt is made (``max_retries=0``). The test checks the
    top-level fields and the fields of the nested address.
    """

    class Address(BaseModel):
        house_number: int
        street_name: str

    class User(BaseModel):
        name: str
        age: int
        address: Address

    prompt = "Extract John is 18 years old and lives at 123 First Avenue."
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=[{"role": "user", "content": prompt}],
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert (resp.name, resp.age) == ("John", 18)

    # Nested model checks.
    address = resp.address
    assert isinstance(address, Address)
    assert address.house_number == 123
    assert address.street_name == "First Avenue"
74

75

76
def test_list_str():
    """Generate a ``User`` whose ``family`` field is a ``list[str]``.

    A single attempt is made (``max_retries=0``). The test only checks the
    types of the result, not specific values.
    """

    class User(BaseModel):
        name: str
        age: int
        family: list[str]

    request = [
        {
            "role": "user",
            "content": "Create a user for a model with a name, age, and family members.",
        }
    ]
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=request,
        response_model=User,
    )

    assert isinstance(resp, User)
    family = resp.family
    assert isinstance(family, list)
    assert all(isinstance(member, str) for member in family)
99

100

101
@pytest.mark.skip(reason="Just use Literal!")
def test_enum():
    """Extract a ``str``-based ``Enum`` field.

    Skipped unconditionally. The skip reason recommends a ``Literal`` field
    instead of an ``Enum``.
    """

    class Role(str, Enum):
        ADMIN = "admin"
        USER = "user"

    class User(BaseModel):
        name: str
        role: Role

    request = [
        {
            "role": "user",
            "content": "Create a user for a model with a name and role of admin.",
        }
    ]
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=1,
        messages=request,
        response_model=User,
    )

    assert isinstance(resp, User)
    assert resp.role == Role.ADMIN
126

127

128
def test_literal():
    """Extract a ``User`` whose ``role`` is restricted by a ``Literal`` type.

    The prompt asks for an admin user, and the test expects ``role`` to be
    "admin". Up to two retries are allowed.
    """

    class User(BaseModel):
        name: str
        role: Literal["admin", "user"]

    user_message = {
        "role": "user",
        "content": "Create a admin user for a model with a name and role.",
    }
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=2,
        messages=[user_message],
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert resp.role == "admin"
148

149

150
def test_nested_list():
    """Generate a ``User`` whose ``properties`` field is a list of sub-models.

    A single attempt is made (``max_retries=0``). The test checks that every
    element of ``properties`` was parsed into a ``Properties`` instance.
    """

    class Properties(BaseModel):
        key: str
        value: str

    class User(BaseModel):
        name: str
        age: int
        properties: list[Properties]

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=[
            {
                "role": "user",
                "content": "Create a user for a model with a name, age, and properties.",
            }
        ],
        response_model=User,
    )

    assert isinstance(resp, User)
    # Use `prop` rather than `property` to avoid shadowing the builtin.
    assert all(isinstance(prop, Properties) for prop in resp.properties)
176

177

178
def test_system_messages_allcaps():
    """Check that a ``system`` entry in ``messages`` steers the output.

    The system message demands all caps, and the test asserts that the
    extracted name is uppercase. The system message is passed inline in
    ``messages``. NOTE(review): presumably instructor moves it into
    Anthropic's system prompt; confirm against instructor's message handling.
    """

    class User(BaseModel):
        name: str
        age: int

    system_msg = {"role": "system", "content": "EVERYTHING MUST BE IN ALL CAPS"}
    user_msg = {
        "role": "user",
        "content": "Create a user for a model with a name and age.",
    }
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=[system_msg, user_msg],
        response_model=User,
    )

    assert isinstance(resp, User)
    assert resp.name.isupper()
199

200

201
def test_retry_error():
    """Exhausting retries must raise ``InstructorRetryException`` with usage.

    The validator rejects every value, so every attempt fails validation.
    The test requires the retry exception to be raised and to report
    non-zero input and output token usage.
    """

    class User(BaseModel):
        name: str

        @field_validator("name")
        def validate_name(cls, _):
            raise ValueError("Never succeed")

    # pytest.raises fails the test if create() returns normally. A bare
    # try/except would silently pass in that case and hide a regression.
    with pytest.raises(InstructorRetryException) as exc_info:
        client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            max_retries=2,
            messages=[
                {
                    "role": "user",
                    "content": "Extract John is 18 years old",
                },
            ],
            response_model=User,
        )

    usage = exc_info.value.total_usage
    assert usage.input_tokens > 0 and usage.output_tokens > 0
224

225

226
@pytest.mark.asyncio
async def test_async_retry_error():
    """Async variant: exhausting retries raises ``InstructorRetryException``.

    Uses an ``AsyncAnthropic`` client wrapped by instructor. The validator
    rejects every value, so all attempts fail. The test requires the
    exception to be raised and to report non-zero token usage.
    """
    # Named `async_client` so it does not shadow the module-level sync
    # `client`. The mode is passed explicitly to match that client.
    async_client = instructor.from_anthropic(
        anthropic.AsyncAnthropic(), mode=instructor.Mode.ANTHROPIC_TOOLS
    )

    class User(BaseModel):
        name: str

        @field_validator("name")
        def validate_name(cls, _):
            raise ValueError("Never succeed")

    # pytest.raises fails the test if create() returns normally. A bare
    # try/except would silently pass in that case.
    with pytest.raises(InstructorRetryException) as exc_info:
        await async_client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            max_retries=2,
            messages=[
                {
                    "role": "user",
                    "content": "Extract John is 18 years old",
                },
            ],
            response_model=User,
        )

    usage = exc_info.value.total_usage
    assert usage.input_tokens > 0 and usage.output_tokens > 0
252

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.