instructor

Форк
0
/
example_redis.py 
74 строки · 2.0 Кб
1
import redis
2
import functools
3
import inspect
4
import instructor
5

6
from pydantic import BaseModel
7
from openai import OpenAI
8

9
client = instructor.from_openai(OpenAI())
10
cache = redis.Redis("localhost")
11

12

13
def instructor_cache(func):
14
    """Cache a function that returns a Pydantic model"""
15
    return_type = inspect.signature(func).return_annotation
16
    if not issubclass(return_type, BaseModel):
17
        raise ValueError("The return type must be a Pydantic model")
18

19
    @functools.wraps(func)
20
    def wrapper(*args, **kwargs):
21
        key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
22
        # Check if the result is already cached
23
        if (cached := cache.get(key)) is not None:
24
            # Deserialize from JSON based on the return type
25
            if issubclass(return_type, BaseModel):
26
                return return_type.model_validate_json(cached)
27

28
        # Call the function and cache its result
29
        result = func(*args, **kwargs)
30
        serialized_result = result.model_dump_json()
31
        cache.set(key, serialized_result)
32

33
        return result
34

35
    return wrapper
36

37

38
class UserDetail(BaseModel):
39
    name: str
40
    age: int
41

42

43
@instructor_cache
44
def extract(data) -> UserDetail:
45
    # Assuming client.chat.completions.create returns a UserDetail instance
46
    return client.chat.completions.create(
47
        model="gpt-3.5-turbo",
48
        response_model=UserDetail,
49
        messages=[
50
            {"role": "user", "content": data},
51
        ],
52
    )
53

54

55
def test_extract():
56
    import time
57

58
    start = time.perf_counter()
59
    model = extract("Extract jason is 25 years old")
60
    assert model.name.lower() == "jason"
61
    assert model.age == 25
62
    print(f"Time taken: {time.perf_counter() - start}")
63

64
    start = time.perf_counter()
65
    model = extract("Extract jason is 25 years old")
66
    assert model.name.lower() == "jason"
67
    assert model.age == 25
68
    print(f"Time taken: {time.perf_counter() - start}")
69

70

71
if __name__ == "__main__":
72
    test_extract()
73
    # Time taken: 0.798335583996959
74
    # Time taken: 0.00017016706988215446
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.