instructor
125 lines · 3.5 KB
1import functools2import inspect3import instructor4import diskcache5
6from openai import OpenAI, AsyncOpenAI7from pydantic import BaseModel8
# Patch the sync and async OpenAI clients with instructor so that
# chat.completions.create accepts a `response_model=` Pydantic type.
client = instructor.from_openai(OpenAI())
aclient = instructor.from_openai(AsyncOpenAI())
12
class UserDetail(BaseModel):
    """Structured extraction target: a person's name and age."""

    name: str
    age: int
17
# Persistent on-disk cache shared by every function wrapped with @instructor_cache.
cache = diskcache.Cache("./my_cache_directory")
20
def instructor_cache(func):
    """Cache a function that returns a Pydantic model.

    The wrapped function's return annotation must be a Pydantic ``BaseModel``
    subclass.  Results are serialized with ``model_dump_json`` and stored in
    the module-level diskcache ``cache``, keyed by the function name plus its
    arguments; cache hits are rebuilt with ``model_validate_json``.  Both
    plain functions and coroutine functions are supported.

    Raises:
        ValueError: if the wrapped function's return annotation is missing
            or is not a BaseModel subclass.
    """
    return_type = inspect.signature(func).return_annotation
    # Guard with isinstance(..., type): a missing annotation is
    # inspect.Signature.empty (not a class), and issubclass() on a
    # non-class raises TypeError, masking the intended ValueError.
    if not (isinstance(return_type, type) and issubclass(return_type, BaseModel)):
        raise ValueError("The return type must be a Pydantic model")

    is_async = inspect.iscoroutinefunction(func)

    def _cache_key(args, kwargs):
        # NOTE: functools._make_key is a private helper, but it builds a
        # flat hashable key from args/kwargs exactly as lru_cache does.
        return f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = _cache_key(args, kwargs)
        # Cache hit: deserialize the stored JSON back into the model type.
        # (return_type is already validated above, so no re-check is needed.)
        if (cached := cache.get(key)) is not None:
            return return_type.model_validate_json(cached)

        # Cache miss: call the function and store its serialized result.
        result = func(*args, **kwargs)
        cache.set(key, result.model_dump_json())
        return result

    @functools.wraps(func)
    async def awrapper(*args, **kwargs):
        key = _cache_key(args, kwargs)
        if (cached := cache.get(key)) is not None:
            return return_type.model_validate_json(cached)

        result = await func(*args, **kwargs)
        cache.set(key, result.model_dump_json())
        return result

    return awrapper if is_async else wrapper
63
@instructor_cache
def extract(data) -> UserDetail:
    """Extract a UserDetail from free-form text via the chat API (disk-cached)."""
    messages = [{"role": "user", "content": data}]
    return client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=messages,
    )  # type: ignore
74
@instructor_cache
async def aextract(data) -> UserDetail:
    """Async variant of extract: same prompt, same cache, via aclient."""
    messages = [{"role": "user", "content": data}]
    return await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=messages,
    )  # type: ignore
85
def test_extract():
    """Call extract twice with the same prompt; the second call should be a cache hit."""
    import time

    # Two identical rounds: first populates the cache, second reads it back.
    for _ in range(2):
        start = time.perf_counter()
        model = extract("Extract jason is 25 years old")
        assert model.name.lower() == "jason"
        assert model.age == 25
        print(f"Time taken: {time.perf_counter() - start}")
101
async def atest_extract():
    """Async twin of test_extract: second identical call should hit the cache."""
    import time

    # Two identical rounds: first populates the cache, second reads it back.
    for _ in range(2):
        start = time.perf_counter()
        model = await aextract("Extract jason is 25 years old")
        assert model.name.lower() == "jason"
        assert model.age == 25
        print(f"Time taken: {time.perf_counter() - start}")
117
if __name__ == "__main__":
    # Sync demo: the second timing line shows the cache hit.
    test_extract()
    # Time taken: 0.7285366660216823
    # Time taken: 9.841693099588156e-05

    import asyncio

    # Async demo runs in its own event loop.
    asyncio.run(atest_extract())