llama-index

216 lines · 7.3 KB
from __future__ import annotations

from typing import Any, Tuple, cast

from deprecated import deprecated

from llama_index.legacy.bridge.pydantic import PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.llm_predictor.base import LLM, BaseLLMPredictor, LLMMetadata
from llama_index.legacy.llm_predictor.vellum.exceptions import VellumGenerateException
from llama_index.legacy.llm_predictor.vellum.prompt_registry import VellumPromptRegistry
from llama_index.legacy.llm_predictor.vellum.types import (
    VellumCompiledPrompt,
    VellumRegisteredPrompt,
)
from llama_index.legacy.prompts import BasePromptTemplate
from llama_index.legacy.types import TokenAsyncGen, TokenGen


@deprecated("VellumPredictor is deprecated and will be removed in a future release.")
class VellumPredictor(BaseLLMPredictor):
    _callback_manager: CallbackManager = PrivateAttr(default_factory=CallbackManager)

    _vellum_client: Any = PrivateAttr()
    _async_vellum_client = PrivateAttr()
    _prompt_registry: Any = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        vellum_api_key: str,
        callback_manager: CallbackManager | None = None,
    ) -> None:
        import_err_msg = (
            "`vellum` package not found, please run `pip install vellum-ai`"
        )
        try:
            from vellum.client import AsyncVellum, Vellum
        except ImportError:
            raise ImportError(import_err_msg)

        self._callback_manager = callback_manager or CallbackManager([])

        # Vellum-specific
        self._vellum_client = Vellum(api_key=vellum_api_key)
        self._async_vellum_client = AsyncVellum(api_key=vellum_api_key)
        self._prompt_registry = VellumPromptRegistry(vellum_api_key=vellum_api_key)

        super().__init__()

    @classmethod
    def class_name(cls) -> str:
        return "VellumPredictor"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        # Note: We use default values here, but ideally we would retrieve this metadata
        # via Vellum's API based on the LLM that backs the registered prompt's
        # deployment. This is not currently possible, so we use default values.
        return LLMMetadata()
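
    # Hedged note (not part of the original module): `LLMMetadata()` reports the
    # library-wide defaults for context window and output length. If the model
    # behind the Vellum deployment is known, a subclass could report explicit
    # values instead, e.g. (illustrative numbers only):
    #
    #     return LLMMetadata(context_window=4096, num_output=256)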

    @property
    def callback_manager(self) -> CallbackManager:
        """Get callback manager."""
        return self._callback_manager

    @property
    def llm(self) -> LLM:
        """Get the LLM."""
        raise NotImplementedError("Vellum does not expose the LLM.")

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Predict the answer to a query."""
        from vellum import GenerateRequest

        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        input_values = {
            **prompt.kwargs,
            **prompt_args,
        }
        result = self._vellum_client.generate(
            deployment_id=registered_prompt.deployment_id,
            requests=[GenerateRequest(input_values=input_values)],
        )

        return self._process_generate_response(result, compiled_prompt, event_id)
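
    # Hedged usage sketch (illustrative names, not part of the original module):
    # for a prompt template with an input variable "question", a blocking call
    # would look like
    #
    #     answer = predictor.predict(qa_prompt, question="What does Vellum do?")
    #
    # `prompt.kwargs` above merges in any variables already bound on the
    # template, with explicit keyword arguments taking precedence.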

    def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:
        """Stream the answer to a query."""
        from vellum import GenerateRequest, GenerateStreamResult

        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        input_values = {
            **prompt.kwargs,
            **prompt_args,
        }
        responses = self._vellum_client.generate_stream(
            deployment_id=registered_prompt.deployment_id,
            requests=[GenerateRequest(input_values=input_values)],
        )

        def text_generator() -> TokenGen:
            complete_text = ""

            while True:
                try:
                    stream_response = next(responses)
                except StopIteration:
                    self.callback_manager.on_event_end(
                        CBEventType.LLM,
                        payload={
                            EventPayload.RESPONSE: complete_text,
                            EventPayload.PROMPT: compiled_prompt.text,
                        },
                        event_id=event_id,
                    )
                    break

                result: GenerateStreamResult = stream_response.delta

                if result.error:
                    raise VellumGenerateException(result.error.message)
                elif not result.data:
                    raise VellumGenerateException(
                        "Unknown error occurred while generating"
                    )

                completion_text_delta = result.data.completion.text
                complete_text += completion_text_delta

                yield completion_text_delta

        return text_generator()
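
    # Hedged consumption sketch (not part of the original module): `stream()`
    # returns a plain generator of text deltas, so it can be iterated directly,
    # e.g.
    #
    #     for token in predictor.stream(qa_prompt, question="..."):
    #         print(token, end="", flush=True)
    #
    # The LLM end event is only emitted once the generator is exhausted.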

    async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Asynchronously predict the answer to a query."""
        from vellum import GenerateRequest

        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        input_values = {
            **prompt.kwargs,
            **prompt_args,
        }
        result = await self._async_vellum_client.generate(
            deployment_id=registered_prompt.deployment_id,
            requests=[GenerateRequest(input_values=input_values)],
        )

        return self._process_generate_response(result, compiled_prompt, event_id)

    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
        async def gen() -> TokenAsyncGen:
            for token in self.stream(prompt, **prompt_args):
                yield token

        # NOTE: convert generator to async generator
        return gen()
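
    # Hedged note (not part of the original module): the async generator above
    # merely wraps the synchronous `stream()`, so the underlying HTTP calls
    # still block the event loop. Consumption would look like
    #
    #     async for token in await predictor.astream(qa_prompt, question="..."):
    #         print(token, end="")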

    def _prepare_generate_call(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> Tuple[VellumRegisteredPrompt, VellumCompiledPrompt, str]:
        """Prepare a generate call."""
        registered_prompt = self._prompt_registry.from_prompt(prompt)
        compiled_prompt = self._prompt_registry.get_compiled_prompt(
            registered_prompt, prompt_args
        )

        cb_payload = {
            **prompt_args,
            "deployment_id": registered_prompt.deployment_id,
            "model_version_id": registered_prompt.model_version_id,
        }
        event_id = self.callback_manager.on_event_start(
            CBEventType.LLM,
            payload=cb_payload,
        )
        return registered_prompt, compiled_prompt, event_id
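
    # Hedged observability sketch (illustrative wiring, not part of the original
    # module): because start/end events flow through the callback manager, a
    # handler such as LlamaDebugHandler would capture the deployment id, the
    # compiled prompt and the response:
    #
    #     from llama_index.legacy.callbacks import CallbackManager, LlamaDebugHandler
    #
    #     debug_handler = LlamaDebugHandler()
    #     predictor = VellumPredictor(
    #         vellum_api_key="...",
    #         callback_manager=CallbackManager([debug_handler]),
    #     )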

    def _process_generate_response(
        self,
        result: Any,
        compiled_prompt: VellumCompiledPrompt,
        event_id: str,
    ) -> str:
        """Process the response from a generate call."""
        from vellum import GenerateResponse

        result = cast(GenerateResponse, result)

        completion_text = result.text

        self.callback_manager.on_event_end(
            CBEventType.LLM,
            payload={
                EventPayload.RESPONSE: completion_text,
                EventPayload.PROMPT: compiled_prompt.text,
            },
            event_id=event_id,
        )

        return completion_text
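

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of wiring the predictor up end to end. The
# environment variable name, the prompt text, and the question are assumptions
# for illustration; the prompt's Vellum deployment is looked up (or registered)
# by VellumPromptRegistry on first use.
if __name__ == "__main__":
    import os

    from llama_index.legacy.prompts import PromptTemplate

    predictor = VellumPredictor(
        vellum_api_key=os.environ["VELLUM_API_KEY"]  # assumed env var
    )
    qa_prompt = PromptTemplate("Answer the question concisely: {question}")

    # Blocking call; see `stream()` / `astream()` for incremental tokens.
    print(predictor.predict(qa_prompt, question="What is Vellum?"))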
