instructor

Форк
0
100 строк · 3.1 Кб
1
from pydantic import BaseModel, ValidationInfo, model_validator
2
import openai
3
import instructor
4
import asyncio
5

6
client = instructor.from_openai(
7
    openai.AsyncOpenAI(),
8
)
9

10

11
class Tag(BaseModel):
    # A tag identified by a numeric ``id`` and a ``name``.
    #
    # NOTE: documented with comments instead of a class docstring on purpose.
    # Pydantic exposes a model's docstring as the description of its JSON
    # schema, so adding a docstring here would change the schema generated
    # for this model.
    id: int
    name: str

    @model_validator(mode="after")
    def validate_ids(self, info: ValidationInfo):
        """Check this tag against the allowed tags from the validation context.

        Validation is skipped when no context is supplied or when the context
        has no ``"tags"`` entry. Otherwise ``context["tags"]`` must be an
        iterable of objects exposing ``id`` and ``name``.

        Raises:
            ValueError: If the ID is unknown, the name is unknown, or the
                ID/name pair does not match any single allowed tag. Pydantic
                reports it to the caller as a ``ValidationError``.
        """
        context = info.context
        if not context:
            return self

        tags = context.get("tags")
        if tags is None:
            # No allow-list supplied: nothing to check against.
            return self

        # Materialize once so the allow-list can be scanned several times.
        allowed_pairs = [(tag.id, tag.name) for tag in tags]

        # Explicit raises instead of ``assert``: assertions are removed when
        # Python runs with ``-O``, which would silently disable the checks.
        if self.id not in {tag_id for tag_id, _ in allowed_pairs}:
            raise ValueError(f"Tag ID {self.id} not found in context")
        if self.name not in {tag_name for _, tag_name in allowed_pairs}:
            raise ValueError(f"Tag name {self.name} not found in context")
        # ID and name must come from the same allowed tag, not merely each
        # appear somewhere in the allow-list.
        if (self.id, self.name) not in allowed_pairs:
            raise ValueError(
                f"Tag name {self.name} does not match tag ID {self.id} in context"
            )
        return self
27

28

29
class TagWithInstructions(Tag):
    # A ``Tag`` extended with a free-text ``instructions`` string.
    # (Comment rather than docstring so the model's schema description
    # stays unchanged.)
    instructions: str
31

32

33
class TagRequest(BaseModel):
    # Input of a tagging run: the texts to tag and the tags to choose from.
    # (Comment rather than docstring so the model's schema description
    # stays unchanged.)
    texts: list[str]  # texts to be tagged
    tags: list[TagWithInstructions]  # candidate tags
36

37

38
class TagResponse(BaseModel):
    # Output of a tagging run: the input texts and the predicted tags.
    # (Comment rather than docstring so the model's schema description
    # stays unchanged.)
    texts: list[str]  # the texts that were tagged
    predictions: list[Tag]  # predicted tags
41

42

43
async def tag_single_request(text: str, tags: list[Tag]) -> Tag:
    """Request a single tag prediction for ``text`` from the chat model.

    Args:
        text: The text to tag.
        tags: The tags the model is allowed to choose from. They are listed
            in the prompt and also passed as the validation context.

    Returns:
        The awaited result of the completion call made with
        ``response_model=Tag``.
    """
    # Each allowed tag is rendered as a back-quoted ``(id, name)`` tuple.
    allowed_tags_str = ", ".join(f"`{(tag.id, tag.name)}`" for tag in tags)

    system_message = {
        "role": "system",
        "content": "You are a world-class text tagging system.",
    }
    text_message = {
        "role": "user",
        "content": f"Describe the following text: `{text}`",
    }
    tags_message = {
        "role": "user",
        "content": f"Here are the allowed tags: {allowed_tags_str}",
    }

    prediction = await client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[system_message, text_message, tags_message],
        response_model=Tag,
        # Minimizes the hallucination of tags that are not in the allowed
        # tags: the list is handed over as validation context.
        validation_context={"tags": tags},
    )
    return prediction
63

64

65
async def tag_request(request: TagRequest) -> TagResponse:
    """Tag every text of ``request`` concurrently.

    Args:
        request: The texts to tag and the tags to choose from.

    Returns:
        A ``TagResponse`` carrying the request's texts and the gathered
        predictions.
    """
    # One ``tag_single_request`` coroutine per text, all awaited together.
    pending = (tag_single_request(text, request.tags) for text in request.texts)
    results = await asyncio.gather(*pending)
    return TagResponse(texts=request.texts, predictions=results)
73

74

75
if __name__ == "__main__":
    # Demo run: tag a few sample questions and print the result as JSON.

    # The tags cover a range of topics (personal, phone, email, ...),
    # given here as (id, name, instructions) triples.
    tag_specs = [
        (0, "personal", "Personal information"),
        (1, "phone", "Phone number"),
        (2, "email", "Email address"),
        (3, "address", "Address"),
        (4, "Other", "Other information"),
    ]
    tags = [
        TagWithInstructions(id=tag_id, name=tag_name, instructions=tag_hint)
        for tag_id, tag_name, tag_hint in tag_specs
    ]

    # The texts are sample questions to be tagged.
    texts = [
        "What is your phone number?",
        "What is your email address?",
        "What is your address?",
        "What is your privacy policy?",
    ]

    # Bundle the texts and the tags into a single request.
    request = TagRequest(texts=texts, tags=tags)

    # The response holds the texts and the predicted tags.
    response = asyncio.run(tag_request(request))
    print(response.model_dump_json(indent=2))
101

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.