# instructor
# 100 lines · 3.1 KB
1from pydantic import BaseModel, ValidationInfo, model_validator
2import openai
3import instructor
4import asyncio
5
# Shared instructor client: OpenAI's async SDK client patched by instructor so
# that `chat.completions.create` accepts `response_model=` and returns parsed
# pydantic objects. AsyncOpenAI() is built without arguments, so credentials
# are resolved by the SDK itself (by default from OPENAI_API_KEY — confirm for
# your deployment).
client = instructor.from_openai(openai.AsyncOpenAI())
9
10
class Tag(BaseModel):
    # A single tag, identified by both its numeric `id` and its `name`.
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a class docstring into the model's JSON schema as its
    # "description", and this model is used as instructor's response_model,
    # so a docstring would change what is sent to the LLM.
    id: int
    name: str

    @model_validator(mode="after")
    def validate_ids(self, info: ValidationInfo) -> "Tag":
        """Reject a tag that is not one of the allowed tags in the validation context.

        The check only runs when validation is given a context containing a
        ``"tags"`` entry (an iterable of objects with ``id`` and ``name``).
        Without a context, or without that key, the tag is accepted unchanged.

        Raises:
            ValueError: If the id is unknown, the name is unknown, or the
                ``(id, name)`` pair does not match any single allowed tag.
                pydantic converts this into a ValidationError.
        """
        context = info.context
        if not context:
            return self

        tags = context.get("tags")
        if tags is None:
            # Nothing to validate against. The original iterated None here
            # and crashed with TypeError.
            return self

        allowed_pairs = {(tag.id, tag.name) for tag in tags}
        allowed_ids = {tag_id for tag_id, _ in allowed_pairs}
        allowed_names = {tag_name for _, tag_name in allowed_pairs}

        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would silently disable this check.
        if self.id not in allowed_ids:
            raise ValueError(f"Tag ID {self.id} not found in context")
        if self.name not in allowed_names:
            raise ValueError(f"Tag name {self.name} not found in context")
        # Checking id and name separately would accept a mixed-up pair such
        # as (1, "email") when 1 and "email" belong to different tags.
        if (self.id, self.name) not in allowed_pairs:
            raise ValueError(
                f"Tag ID {self.id} does not match tag name {self.name} in context"
            )
        return self
27
28
class TagWithInstructions(Tag):
    # An allowed tag together with free-text guidance about it (the demo data
    # uses short strings such as "Phone number"). Inherits `id`, `name` and
    # the context-based validator from Tag. Comments rather than a docstring
    # keep the generated JSON schema unchanged.
    instructions: str
31
32
class TagRequest(BaseModel):
    # Input to tag_request. Each entry of `texts` is tagged separately,
    # always choosing from the same list of allowed `tags`. Comments rather
    # than a docstring keep the generated JSON schema unchanged.
    texts: list[str]
    tags: list[TagWithInstructions]
36
37
class TagResponse(BaseModel):
    # Output of tag_request: the input texts echoed back, plus one predicted
    # Tag per text, where `predictions[i]` corresponds to `texts[i]`.
    # Comments rather than a docstring keep the generated JSON schema
    # unchanged.
    texts: list[str]
    predictions: list[Tag]
41
42
async def tag_single_request(text: str, tags: list[Tag]) -> Tag:
    """Ask the LLM to choose one of ``tags`` for ``text``.

    Args:
        text: The text to tag.
        tags: The allowed tags. Items with a non-empty ``instructions``
            attribute (e.g. ``TagWithInstructions``) have that guidance
            included in the prompt. Plain ``Tag`` items are listed by id and
            name only.

    Returns:
        The ``Tag`` parsed by instructor from the model's response. It is
        validated by Tag's model validator with ``tags`` as the validation
        context.

    Raises:
        ValueError: If ``tags`` is empty. No response could ever validate in
            that case, so the API is not called at all.
        Errors from the OpenAI client or instructor (e.g. when validation
        retries are exhausted) propagate unchanged.
    """
    if not tags:
        raise ValueError("tag_single_request requires at least one allowed tag")

    # One line per allowed tag. Include the per-tag instructions when present;
    # previously they were carried in the request but never shown to the model.
    tag_lines = []
    for tag in tags:
        line = f"- `{(tag.id, tag.name)}`"
        instructions = getattr(tag, "instructions", None)
        if instructions:
            line += f": {instructions}"
        tag_lines.append(line)
    allowed_tags_str = "\n".join(tag_lines)

    return await client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[
            {
                "role": "system",
                "content": "You are a world-class text tagging system.",
            },
            {"role": "user", "content": f"Tag the following text: `{text}`"},
            {
                "role": "user",
                "content": f"Here are the allowed tags:\n{allowed_tags_str}",
            },
        ],
        response_model=Tag,
        # Minimizes hallucinated tags: Tag's validator rejects anything not in
        # the allowed list, and instructor can then re-ask the model (subject
        # to its max_retries setting).
        validation_context={"tags": tags},
    )
63
64
async def tag_request(request: TagRequest) -> TagResponse:
    """Tag every text in ``request`` against ``request.tags``.

    One ``tag_single_request`` coroutine is created per text, and all of them
    are awaited together with ``asyncio.gather``. gather returns results in
    the order the awaitables were passed, so ``predictions[i]`` is the tag for
    ``texts[i]``.
    """
    pending = (tag_single_request(item, request.tags) for item in request.texts)
    results = await asyncio.gather(*pending)
    return TagResponse(texts=request.texts, predictions=results)
73
74
if __name__ == "__main__":
    # Demo: tag a few short questions against a small set of topic tags.

    # Allowed tags cover a range of topics (personal, phone, email, ...).
    # Names are consistently lowercase because Tag's validator matches names
    # exactly (case-sensitive). The former lone "Other" was the odd one out.
    tags = [
        TagWithInstructions(id=0, name="personal", instructions="Personal information"),
        TagWithInstructions(id=1, name="phone", instructions="Phone number"),
        TagWithInstructions(id=2, name="email", instructions="Email address"),
        TagWithInstructions(id=3, name="address", instructions="Address"),
        TagWithInstructions(id=4, name="other", instructions="Other information"),
    ]

    # Texts are short questions such as "How much does it cost?" or
    # "What is your privacy policy?".
    texts = [
        "What is your phone number?",
        "What is your email address?",
        "What is your address?",
        "What is your privacy policy?",
    ]

    # The request bundles the texts with the tags they may be assigned.
    request = TagRequest(texts=texts, tags=tags)

    # The response holds the input texts and one predicted Tag per text, in
    # the same order. There is no confidence score.
    response = asyncio.run(tag_request(request))
    print(response.model_dump_json(indent=2))
101