instructor

example_diskcache.py 
125 lines · 3.5 KB
import functools
import inspect
import instructor
import diskcache

from openai import OpenAI, AsyncOpenAI
from pydantic import BaseModel

# Patch the OpenAI clients so completions can be parsed into Pydantic response models
client = instructor.from_openai(OpenAI())
aclient = instructor.from_openai(AsyncOpenAI())


class UserDetail(BaseModel):
    name: str
    age: int


cache = diskcache.Cache("./my_cache_directory")
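# The cache lives on disk, so cached results survive interpreter restarts;
# cache.clear() drops all entries, e.g. when iterating on prompts or models.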


def instructor_cache(func):
    """Cache a function that returns a Pydantic model"""
    return_type = inspect.signature(func).return_annotation
    if not issubclass(return_type, BaseModel):
        raise ValueError("The return type must be a Pydantic model")

    is_async = inspect.iscoroutinefunction(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
        # Check if the result is already cached
        if (cached := cache.get(key)) is not None:
            # Deserialize from JSON based on the return type
            if issubclass(return_type, BaseModel):
                return return_type.model_validate_json(cached)

        # Call the function and cache its result
        result = func(*args, **kwargs)
        serialized_result = result.model_dump_json()
        cache.set(key, serialized_result)

        return result

    @functools.wraps(func)
    async def awrapper(*args, **kwargs):
        key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
        # Check if the result is already cached
        if (cached := cache.get(key)) is not None:
            # Deserialize from JSON based on the return type
            if issubclass(return_type, BaseModel):
                return return_type.model_validate_json(cached)

        # Call the function and cache its result
        result = await func(*args, **kwargs)
        serialized_result = result.model_dump_json()
        cache.set(key, serialized_result)

        return result

    # Return the async wrapper for coroutine functions, the sync wrapper otherwise
    return wrapper if not is_async else awrapper
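
# Note: functools._make_key is the private key-building helper used internally by
# functools.lru_cache; it is not a public API, so an explicit key such as
# f"{func.__name__}-{args}-{kwargs}" would serve the same purpose if that is a concern.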


@instructor_cache
def extract(data) -> UserDetail:
    return client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=[
            {"role": "user", "content": data},
        ],
    )  # type: ignore


@instructor_cache
async def aextract(data) -> UserDetail:
    return await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=[
            {"role": "user", "content": data},
        ],
    )  # type: ignore


def test_extract():
    import time

    start = time.perf_counter()
    model = extract("Extract jason is 25 years old")
    assert model.name.lower() == "jason"
    assert model.age == 25
    print(f"Time taken: {time.perf_counter() - start}")

    # The second call with the same arguments is served from the disk cache
    start = time.perf_counter()
    model = extract("Extract jason is 25 years old")
    assert model.name.lower() == "jason"
    assert model.age == 25
    print(f"Time taken: {time.perf_counter() - start}")


async def atest_extract():
    import time

    start = time.perf_counter()
    model = await aextract("Extract jason is 25 years old")
    assert model.name.lower() == "jason"
    assert model.age == 25
    print(f"Time taken: {time.perf_counter() - start}")

    # The second call with the same arguments is served from the disk cache
    start = time.perf_counter()
    model = await aextract("Extract jason is 25 years old")
    assert model.name.lower() == "jason"
    assert model.age == 25
    print(f"Time taken: {time.perf_counter() - start}")


if __name__ == "__main__":
    test_extract()
    # Time taken: 0.7285366660216823
    # Time taken: 9.841693099588156e-05

    import asyncio

    asyncio.run(atest_extract())