llama-index

Форк
0
348 строк · 12.3 Кб
1
import asyncio
2
from abc import abstractmethod
3
from contextlib import contextmanager
4
from typing import (
5
    Any,
6
    AsyncGenerator,
7
    Callable,
8
    Generator,
9
    Sequence,
10
    cast,
11
)
12

13
from llama_index.legacy.bridge.pydantic import Field, validator
14
from llama_index.legacy.callbacks import CallbackManager, CBEventType, EventPayload
15
from llama_index.legacy.core.llms.types import (
16
    ChatMessage,
17
    ChatResponse,
18
    ChatResponseAsyncGen,
19
    ChatResponseGen,
20
    CompletionResponse,
21
    CompletionResponseAsyncGen,
22
    CompletionResponseGen,
23
    LLMMetadata,
24
)
25
from llama_index.legacy.core.query_pipeline.query_component import (
26
    ChainableMixin,
27
)
28
from llama_index.legacy.schema import BaseComponent
29

30

31
def llm_chat_callback() -> Callable:
    """Return a decorator that instruments an LLM chat method with LLM events.

    The decorated method is called as ``method(self, messages, **kwargs)``.
    Its owner must expose a ``callback_manager`` attribute that is a
    ``CallbackManager``; otherwise a ``ValueError`` is raised when the
    decorated method is called.

    Event emission:
    - Every call emits a ``CBEventType.LLM`` start event carrying the messages,
      the extra keyword arguments and ``self.to_dict()``.
    - The matching end event carries the messages and the response.
    - When the method returns a (async) generator, that generator is wrapped so
      the end event fires only after the stream has been fully consumed. The
      end event then reports the last yielded item.

    A raw function that has already been decorated once is tracked through a
    ``__wrapped__`` marker set on it. It is wrapped in a plain pass-through
    that emits no events, presumably to avoid duplicate events.

    Returns:
        A decorator accepting either a sync or an async method.
    """

    def _resolve_callback_manager(_self: Any) -> CallbackManager:
        # Fail fast when the owning object has no usable callback manager.
        manager = getattr(_self, "callback_manager", None)
        if isinstance(manager, CallbackManager):
            return manager
        raise ValueError(
            "Cannot use llm_chat_callback on an instance "
            "without a callback_manager attribute."
        )

    def _start_event(
        manager: CallbackManager,
        _self: Any,
        messages: Sequence[ChatMessage],
        kwargs: Any,
    ) -> Any:
        # Emit the LLM start event and hand back its id for the end event.
        return manager.on_event_start(
            CBEventType.LLM,
            payload={
                EventPayload.MESSAGES: messages,
                EventPayload.ADDITIONAL_KWARGS: kwargs,
                EventPayload.SERIALIZED: _self.to_dict(),
            },
        )

    def _end_event(
        manager: CallbackManager,
        event_id: Any,
        messages: Sequence[ChatMessage],
        response: Any,
    ) -> None:
        # Emit the LLM end event paired with the given start event id.
        manager.on_event_end(
            CBEventType.LLM,
            payload={
                EventPayload.MESSAGES: messages,
                EventPayload.RESPONSE: response,
            },
            event_id=event_id,
        )

    def wrap(f: Callable) -> Callable:
        async def wrapped_async_llm_chat(
            _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
        ) -> Any:
            manager = _resolve_callback_manager(_self)
            event_id = _start_event(manager, _self, messages, kwargs)
            result = await f(_self, messages, **kwargs)

            if not isinstance(result, AsyncGenerator):
                _end_event(manager, event_id, messages, result)
                return result

            # Streaming: defer the end event until the stream is exhausted.
            async def wrapped_gen() -> ChatResponseAsyncGen:
                final_chunk = None
                async for chunk in result:
                    yield cast(ChatResponse, chunk)
                    final_chunk = chunk
                _end_event(manager, event_id, messages, final_chunk)

            return wrapped_gen()

        def wrapped_llm_chat(
            _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
        ) -> Any:
            manager = _resolve_callback_manager(_self)
            event_id = _start_event(manager, _self, messages, kwargs)
            result = f(_self, messages, **kwargs)

            if not isinstance(result, Generator):
                _end_event(manager, event_id, messages, result)
                return result

            # Streaming: defer the end event until the stream is exhausted.
            def wrapped_gen() -> ChatResponseGen:
                final_chunk = None
                for chunk in result:
                    yield cast(ChatResponse, chunk)
                    final_chunk = chunk
                _end_event(manager, event_id, messages, final_chunk)

            return wrapped_gen()

        async def async_dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return await f(_self, *args, **kwargs)

        def dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return f(_self, *args, **kwargs)

        # Mark the raw function on first decoration; later decorations of the
        # same function get the event-free pass-through wrappers.
        already_wrapped = getattr(f, "__wrapped__", False)
        if not already_wrapped:
            f.__wrapped__ = True  # type: ignore

        if asyncio.iscoroutinefunction(f):
            return async_dummy_wrapper if already_wrapped else wrapped_async_llm_chat
        return dummy_wrapper if already_wrapped else wrapped_llm_chat

    return wrap
155

156

157
def llm_completion_callback() -> Callable:
    """Return a decorator that instruments an LLM completion method with LLM events.

    The decorated method is called as ``method(self, prompt, *args, **kwargs)``.
    The prompt may also be passed as ``prompt=...``. The owner must expose a
    ``callback_manager`` attribute that is a ``CallbackManager``; otherwise a
    ``ValueError`` is raised when the decorated method is called.

    Event emission:
    - Every call emits a ``CBEventType.LLM`` start event carrying the prompt,
      the extra keyword arguments and ``self.to_dict()``.
    - The matching end event carries the prompt and the completion under
      ``EventPayload.COMPLETION`` on every path (sync, async, streaming and
      async streaming).
    - When the method returns a (async) generator, that generator is wrapped so
      the end event fires only after the stream has been fully consumed. The
      end event then reports the last yielded item.

    A raw function that has already been decorated once is tracked through a
    ``__wrapped__`` marker set on it. It gets a plain pass-through wrapper
    that emits no events.

    Returns:
        A decorator accepting either a sync or an async method.
    """

    def wrap(f: Callable) -> Callable:
        @contextmanager
        def wrapper_logic(_self: Any) -> Generator[CallbackManager, None, None]:
            # Validate that the owning object can report events.
            callback_manager = getattr(_self, "callback_manager", None)
            if not isinstance(callback_manager, CallbackManager):
                raise ValueError(
                    "Cannot use llm_completion_callback on an instance "
                    "without a callback_manager attribute."
                )

            yield callback_manager

        def _get_prompt(args: Sequence[Any], kwargs: Any) -> Any:
            """Return the prompt from the first positional arg or the ``prompt`` kwarg."""
            # Indexing args[0] unconditionally raised IndexError for calls such
            # as ``llm.complete(prompt="...")``. If no prompt is supplied at all,
            # None is reported and the wrapped method is left to reject the call.
            if args:
                return args[0]
            return kwargs.get("prompt")

        async def wrapped_async_llm_predict(
            _self: Any, *args: Any, **kwargs: Any
        ) -> Any:
            with wrapper_logic(_self) as callback_manager:
                prompt = _get_prompt(args, kwargs)
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.PROMPT: prompt,
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )

                f_return_val = await f(_self, *args, **kwargs)

                if isinstance(f_return_val, AsyncGenerator):
                    # Intercept the stream and emit the end event once exhausted.
                    async def wrapped_gen() -> CompletionResponseAsyncGen:
                        last_response = None
                        async for x in f_return_val:
                            yield cast(CompletionResponse, x)
                            last_response = x

                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.PROMPT: prompt,
                                EventPayload.COMPLETION: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.PROMPT: prompt,
                            # COMPLETION (not RESPONSE) so this path matches the
                            # sync and streaming completion paths.
                            EventPayload.COMPLETION: f_return_val,
                        },
                        event_id=event_id,
                    )

            return f_return_val

        def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
            with wrapper_logic(_self) as callback_manager:
                prompt = _get_prompt(args, kwargs)
                event_id = callback_manager.on_event_start(
                    CBEventType.LLM,
                    payload={
                        EventPayload.PROMPT: prompt,
                        EventPayload.ADDITIONAL_KWARGS: kwargs,
                        EventPayload.SERIALIZED: _self.to_dict(),
                    },
                )

                f_return_val = f(_self, *args, **kwargs)
                if isinstance(f_return_val, Generator):
                    # Intercept the stream and emit the end event once exhausted.
                    def wrapped_gen() -> CompletionResponseGen:
                        last_response = None
                        for x in f_return_val:
                            yield cast(CompletionResponse, x)
                            last_response = x

                        callback_manager.on_event_end(
                            CBEventType.LLM,
                            payload={
                                EventPayload.PROMPT: prompt,
                                EventPayload.COMPLETION: last_response,
                            },
                            event_id=event_id,
                        )

                    return wrapped_gen()
                else:
                    callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.PROMPT: prompt,
                            EventPayload.COMPLETION: f_return_val,
                        },
                        event_id=event_id,
                    )

            return f_return_val

        async def async_dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return await f(_self, *args, **kwargs)

        def dummy_wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any:
            return f(_self, *args, **kwargs)

        # Mark the raw function on first decoration; later decorations of the
        # same function get the event-free pass-through wrappers.
        is_wrapped = getattr(f, "__wrapped__", False)
        if not is_wrapped:
            f.__wrapped__ = True  # type: ignore

        if asyncio.iscoroutinefunction(f):
            if is_wrapped:
                return async_dummy_wrapper
            else:
                return wrapped_async_llm_predict
        else:
            if is_wrapped:
                return dummy_wrapper
            else:
                return wrapped_llm_predict

    return wrap
280

281

282
class BaseLLM(ChainableMixin, BaseComponent):
    """LLM interface."""

    # Built per instance via default_factory; exclude=True keeps it out of
    # pydantic serialization output.
    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )

    class Config:
        # Lets pydantic accept field types it cannot validate itself
        # (presumably needed for CallbackManager).
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True)
    def _validate_callback_manager(cls, v: CallbackManager) -> CallbackManager:
        """Swap an explicit ``None`` for an empty ``CallbackManager``; keep anything else."""
        return v if v is not None else CallbackManager([])

    @property
    @abstractmethod
    def metadata(self) -> LLMMetadata:
        """Metadata describing this LLM; must be provided by subclasses."""

    @abstractmethod
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send chat ``messages`` to the LLM and return a single chat response."""

    @abstractmethod
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete ``prompt`` and return a single completion response.

        The meaning of ``formatted`` is left to implementations
        (presumably whether ``prompt`` is already formatted; confirm).
        """

    @abstractmethod
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Send chat ``messages`` and yield chat responses incrementally."""

    @abstractmethod
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Complete ``prompt`` and yield completion responses incrementally."""

    # ----- Async counterparts of the endpoints above -----
    @abstractmethod
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async version of ``chat``."""

    @abstractmethod
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async version of ``complete``."""

    @abstractmethod
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async version of ``stream_chat``."""

    @abstractmethod
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async version of ``stream_complete``."""
349

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.