llama-index
348 lines · 12.3 KB
1import asyncio
2from abc import abstractmethod
3from contextlib import contextmanager
4from typing import (
5Any,
6AsyncGenerator,
7Callable,
8Generator,
9Sequence,
10cast,
11)
12
13from llama_index.legacy.bridge.pydantic import Field, validator
14from llama_index.legacy.callbacks import CallbackManager, CBEventType, EventPayload
15from llama_index.legacy.core.llms.types import (
16ChatMessage,
17ChatResponse,
18ChatResponseAsyncGen,
19ChatResponseGen,
20CompletionResponse,
21CompletionResponseAsyncGen,
22CompletionResponseGen,
23LLMMetadata,
24)
25from llama_index.legacy.core.query_pipeline.query_component import (
26ChainableMixin,
27)
28from llama_index.legacy.schema import BaseComponent
29
30
def llm_chat_callback() -> Callable:
    """Create a decorator that instruments chat endpoints with callback events.

    The returned decorator wraps ``chat``/``achat`` (and their streaming
    variants) so that a ``CBEventType.LLM`` event is opened on the instance's
    ``callback_manager`` before the underlying call and closed after the
    response — or, for streaming results, after the final chunk — is produced.

    The decorated method's instance must expose a ``callback_manager``
    attribute of type ``CallbackManager``; otherwise ``ValueError`` is raised
    at call time.
    """

    def wrap(f: Callable) -> Callable:
        @contextmanager
        def wrapper_logic(_self: Any) -> Generator[CallbackManager, None, None]:
            # Validate and hand back the instance's callback manager.
            callback_manager = getattr(_self, "callback_manager", None)
            if not isinstance(callback_manager, CallbackManager):
                raise ValueError(
                    "Cannot use llm_chat_callback on an instance "
                    "without a callback_manager attribute."
                )

            yield callback_manager

        async def wrapped_async_llm_chat(
            _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
        ) -> Any:
            # Async wrapper: emit the start event, await the call, then emit
            # the end event (deferred to stream exhaustion for generators).
            with wrapper_logic(_self) as callback_manager:
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.MESSAGES: messages,
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )

                f_return_val = await f(_self, messages, **kwargs)
                if isinstance(f_return_val, AsyncGenerator):
                    # intercept the generator and add a callback to the end
                    async def wrapped_gen() -> ChatResponseAsyncGen:
                        last_response = None
                        async for x in f_return_val:
                            yield cast(ChatResponse, x)
                            last_response = x

                        # Report only the final chunk; it carries the fully
                        # accumulated response in streaming protocols.
                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.MESSAGES: messages,
                                EventPayload.RESPONSE: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.MESSAGES: messages,
                            EventPayload.RESPONSE: f_return_val,
                        },
                        event_id=event_id,
                    )

                return f_return_val

        def wrapped_llm_chat(
            _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
        ) -> Any:
            # Sync wrapper: same event protocol as the async variant.
            with wrapper_logic(_self) as callback_manager:
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.MESSAGES: messages,
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )
                f_return_val = f(_self, messages, **kwargs)

                if isinstance(f_return_val, Generator):
                    # intercept the generator and add a callback to the end
                    def wrapped_gen() -> ChatResponseGen:
                        last_response = None
                        for x in f_return_val:
                            yield cast(ChatResponse, x)
                            last_response = x

                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.MESSAGES: messages,
                                EventPayload.RESPONSE: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.MESSAGES: messages,
                            EventPayload.RESPONSE: f_return_val,
                        },
                        event_id=event_id,
                    )

                return f_return_val

        # Pass-through wrappers used when ``f`` was already decorated, so a
        # second application does not emit duplicate callback events.
        async def async_dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return await f(_self, *args, **kwargs)

        def dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return f(_self, *args, **kwargs)

        # check if already wrapped
        is_wrapped = getattr(f, "__wrapped__", False)
        if not is_wrapped:
            # Mark the original so re-decoration (e.g. in a subclass override
            # chain) is detected above.
            f.__wrapped__ = True  # type: ignore

        if asyncio.iscoroutinefunction(f):
            if is_wrapped:
                return async_dummy_wrapper
            else:
                return wrapped_async_llm_chat
        else:
            if is_wrapped:
                return dummy_wrapper
            else:
                return wrapped_llm_chat

    return wrap
155
156
def llm_completion_callback() -> Callable:
    """Create a decorator that instruments completion endpoints with callback events.

    The returned decorator wraps ``complete``/``acomplete`` (and their
    streaming variants) so that a ``CBEventType.LLM`` event is opened on the
    instance's ``callback_manager`` before the underlying call and closed
    after the response — or, for streaming results, after the final chunk —
    is produced.

    The decorated method's instance must expose a ``callback_manager``
    attribute of type ``CallbackManager``; otherwise ``ValueError`` is raised
    at call time. The wrapped method is expected to receive the prompt as its
    first positional argument (``args[0]``).
    """

    def wrap(f: Callable) -> Callable:
        @contextmanager
        def wrapper_logic(_self: Any) -> Generator[CallbackManager, None, None]:
            # Validate and hand back the instance's callback manager.
            callback_manager = getattr(_self, "callback_manager", None)
            if not isinstance(callback_manager, CallbackManager):
                raise ValueError(
                    "Cannot use llm_completion_callback on an instance "
                    "without a callback_manager attribute."
                )

            yield callback_manager

        async def wrapped_async_llm_predict(
            _self: Any, *args: Any, **kwargs: Any
        ) -> Any:
            # Async wrapper: emit the start event, await the call, then emit
            # the end event (deferred to stream exhaustion for generators).
            with wrapper_logic(_self) as callback_manager:
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.PROMPT: args[0],
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )

                f_return_val = await f(_self, *args, **kwargs)

                if isinstance(f_return_val, AsyncGenerator):
                    # intercept the generator and add a callback to the end
                    async def wrapped_gen() -> CompletionResponseAsyncGen:
                        last_response = None
                        async for x in f_return_val:
                            yield cast(CompletionResponse, x)
                            last_response = x

                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.PROMPT: args[0],
                                EventPayload.COMPLETION: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    # FIX: previously this branch reported the result under
                    # EventPayload.RESPONSE, unlike the sync and streaming
                    # branches which all use EventPayload.COMPLETION —
                    # callback handlers keying on COMPLETION silently missed
                    # async non-streaming completions.
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.PROMPT: args[0],
                            EventPayload.COMPLETION: f_return_val,
                        },
                        event_id=event_id,
                    )

                return f_return_val

        def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
            # Sync wrapper: same event protocol as the async variant.
            with wrapper_logic(_self) as callback_manager:
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.PROMPT: args[0],
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )

                f_return_val = f(_self, *args, **kwargs)
                if isinstance(f_return_val, Generator):
                    # intercept the generator and add a callback to the end
                    def wrapped_gen() -> CompletionResponseGen:
                        last_response = None
                        for x in f_return_val:
                            yield cast(CompletionResponse, x)
                            last_response = x

                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.PROMPT: args[0],
                                EventPayload.COMPLETION: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.PROMPT: args[0],
                            EventPayload.COMPLETION: f_return_val,
                        },
                        event_id=event_id,
                    )

                return f_return_val

        # Pass-through wrappers used when ``f`` was already decorated, so a
        # second application does not emit duplicate callback events.
        async def async_dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return await f(_self, *args, **kwargs)

        def dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return f(_self, *args, **kwargs)

        # check if already wrapped
        is_wrapped = getattr(f, "__wrapped__", False)
        if not is_wrapped:
            # Mark the original so re-decoration (e.g. in a subclass override
            # chain) is detected above.
            f.__wrapped__ = True  # type: ignore

        if asyncio.iscoroutinefunction(f):
            if is_wrapped:
                return async_dummy_wrapper
            else:
                return wrapped_async_llm_predict
        else:
            if is_wrapped:
                return dummy_wrapper
            else:
                return wrapped_llm_predict

    return wrap
280
281
class BaseLLM(ChainableMixin, BaseComponent):
    """LLM interface.

    Abstract base for all LLM implementations. Subclasses must provide the
    sync, streaming, and async variants of the ``chat`` and ``complete``
    endpoints, plus a ``metadata`` property describing the model.
    """

    # Per-instance callback manager used by the llm_*_callback decorators;
    # excluded from serialization.
    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )

    class Config:
        # CallbackManager is not a pydantic model, so allow arbitrary types.
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True)
    def _validate_callback_manager(cls, v: CallbackManager) -> CallbackManager:
        # Coerce an explicit None into an empty CallbackManager so the
        # attribute is always usable by the callback decorators.
        if v is None:
            return CallbackManager([])
        return v

    @property
    @abstractmethod
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""

    @abstractmethod
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat endpoint for LLM."""

    @abstractmethod
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint for LLM."""

    @abstractmethod
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat endpoint for LLM."""

    @abstractmethod
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint for LLM."""

    # ===== Async Endpoints =====
    @abstractmethod
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async chat endpoint for LLM."""

    @abstractmethod
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async completion endpoint for LLM."""

    @abstractmethod
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async streaming chat endpoint for LLM."""

    @abstractmethod
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion endpoint for LLM."""
349