from collections import ChainMap
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    get_args,
    runtime_checkable,
)

from llama_index.legacy.bridge.pydantic import BaseModel, Field, validator
from llama_index.legacy.callbacks import CBEventType, EventPayload
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.legacy.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
    StringableInput,
    validate_and_convert_stringable,
)
from llama_index.legacy.llms.base import BaseLLM
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.legacy.prompts import BasePromptTemplate, PromptTemplate
from llama_index.legacy.types import (
    BaseOutputParser,
    PydanticProgramMode,
    TokenAsyncGen,
    TokenGen,
)


# NOTE: These two protocols are needed to appease mypy
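# They describe the call signatures expected for the `messages_to_prompt` and
# `completion_to_prompt` fields on the LLM class below.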
@runtime_checkable
class MessagesToPromptType(Protocol):
    def __call__(self, messages: Sequence[ChatMessage]) -> str:
        pass


@runtime_checkable
class CompletionToPromptType(Protocol):
    def __call__(self, prompt: str) -> str:
        pass


def stream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseGen,
) -> TokenGen:
    """Convert a stream completion response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in completion_response_gen:
            yield response.delta or ""

    return gen()


def stream_chat_response_to_tokens(
    chat_response_gen: ChatResponseGen,
) -> TokenGen:
    """Convert a stream chat response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in chat_response_gen:
            yield response.delta or ""

    return gen()


async def astream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert a stream completion response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in completion_response_gen:
            yield response.delta or ""

    return gen()


async def astream_chat_response_to_tokens(
    chat_response_gen: ChatResponseAsyncGen,
) -> TokenAsyncGen:
98
    """Convert a stream completion response to a stream of tokens."""
99

100
    async def gen() -> TokenAsyncGen:
101
        async for response in chat_response_gen:
102
            yield response.delta or ""
103

104
    return gen()
105

106

107
def default_completion_to_prompt(prompt: str) -> str:
108
    return prompt
109

110

111
class LLM(BaseLLM):
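    """LLM interface with prompt formatting, system prompts, and output parsing.

    Extends ``BaseLLM`` with prompt-template based ``predict``/``stream`` (plus async
    variants), structured prediction via pydantic programs, and query-pipeline
    components.

    Example (illustrative sketch; assumes a concrete subclass such as ``OpenAI``
    and a template with a ``topic`` variable):
        llm = OpenAI(model="gpt-3.5-turbo", system_prompt="Answer briefly.")
        prompt = PromptTemplate("Write a haiku about {topic}.")
        text = llm.predict(prompt, topic="autumn")
        for token in llm.stream(prompt, topic="autumn"):
            print(token, end="")
    """
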
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for LLM calls."
    )
    messages_to_prompt: MessagesToPromptType = Field(
        description="Function to convert a list of messages to an LLM prompt.",
        default=generic_messages_to_prompt,
        exclude=True,
    )
    completion_to_prompt: CompletionToPromptType = Field(
        description="Function to convert a completion to an LLM prompt.",
        default=default_completion_to_prompt,
        exclude=True,
    )
    output_parser: Optional[BaseOutputParser] = Field(
        description="Output parser to parse, validate, and correct errors programmatically.",
        default=None,
        exclude=True,
    )
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    # deprecated
    query_wrapper_prompt: Optional[BasePromptTemplate] = Field(
        description="Query wrapper prompt for LLM calls.",
        default=None,
        exclude=True,
    )

    @validator("messages_to_prompt", pre=True)
    def set_messages_to_prompt(
        cls, messages_to_prompt: Optional[MessagesToPromptType]
    ) -> MessagesToPromptType:
        return messages_to_prompt or generic_messages_to_prompt

    @validator("completion_to_prompt", pre=True)
    def set_completion_to_prompt(
        cls, completion_to_prompt: Optional[CompletionToPromptType]
    ) -> CompletionToPromptType:
        return completion_to_prompt or default_completion_to_prompt

    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
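        # NOTE: ChainMap searches prompt.kwargs before prompt_args, so values already
        # bound on the template win over call-time kwargs in this logging payload.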
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        formatted_prompt = prompt.format(
            llm=self,
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            **prompt_args,
        )
        if self.output_parser is not None:
            formatted_prompt = self.output_parser.format(formatted_prompt)
        return self._extend_prompt(formatted_prompt)

    def _get_messages(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> List[ChatMessage]:
        messages = prompt.format_messages(llm=self, **prompt_args)
        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return self._extend_messages(messages)

    def structured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return program(**prompt_args)
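
    # Illustrative usage sketch for structured_predict (the ``Song`` model and prompt
    # text are hypothetical, not part of this module):
    #
    #   class Song(BaseModel):
    #       title: str
    #       year: int
    #
    #   song = llm.structured_predict(
    #       Song,
    #       PromptTemplate("Name one well-known song by {artist}."),
    #       artist="The Beatles",
    #   )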

    async def astructured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return await program.acall(**prompt_args)

    def _parse_output(self, output: str) -> str:
        if self.output_parser is not None:
            return str(self.output_parser.parse(output))

        return output

    def predict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = self.complete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    def stream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = self.stream_complete(formatted_prompt, formatted=True)
            stream_tokens = stream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = await self.acomplete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    async def astream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = await self.astream_complete(
                formatted_prompt, formatted=True
            )
            stream_tokens = await astream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt

        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            return LLMChatComponent(llm=self, **kwargs)
        else:
            return LLMCompleteComponent(llm=self, **kwargs)

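# Query-pipeline components: the classes below wrap an LLM as a QueryComponent. The
# completion component expects a "prompt" input and the chat component a "messages"
# input; both return the LLM response under the "output" key.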

class BaseLLMComponent(QueryComponent):
    """Base LLM component."""

    llm: LLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        self.llm.callback_manager = callback_manager


class LLMCompleteComponent(BaseLLMComponent):
    """LLM completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # do special check to see if prompt is a list of chat messages
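        # (get_args(List[ChatMessage]) evaluates to (ChatMessage,), so this isinstance
        # check matches a single ChatMessage instance)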
        if isinstance(input["prompt"], get_args(List[ChatMessage])):
            input["prompt"] = self.llm.messages_to_prompt(input["prompt"])
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
            input["prompt"] = self.llm.completion_to_prompt(input["prompt"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        if self.streaming:
            response = self.llm.stream_complete(prompt, formatted=True)
        else:
            response = self.llm.complete(prompt, formatted=True)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        response = await self.llm.acomplete(prompt, formatted=True)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only complete for now
        return InputKeys.from_keys({"prompt"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})


class LLMChatComponent(BaseLLMComponent):
    """LLM chat component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "messages" not in input:
            raise ValueError("Messages must be in input dict.")

        # if `messages` is a string, convert it to a list of chat messages
        if isinstance(input["messages"], get_args(StringableInput)):
            input["messages"] = validate_and_convert_stringable(input["messages"])
            input["messages"] = prompt_to_messages(str(input["messages"]))

        for message in input["messages"]:
            if not isinstance(message, ChatMessage):
                raise ValueError("Messages must be a list of ChatMessage")
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = self.llm.stream_chat(messages)
        else:
            response = self.llm.chat(messages)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = await self.llm.astream_chat(messages)
        else:
            response = await self.llm.achat(messages)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only chat for now
        return InputKeys.from_keys({"messages"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
