from collections import ChainMap
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    get_args,
    runtime_checkable,
)

from llama_index.legacy.bridge.pydantic import BaseModel, Field, validator
from llama_index.legacy.callbacks import CBEventType, EventPayload
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.legacy.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
    StringableInput,
    validate_and_convert_stringable,
)
from llama_index.legacy.llms.base import BaseLLM
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.legacy.prompts import BasePromptTemplate, PromptTemplate
from llama_index.legacy.types import (
    BaseOutputParser,
    PydanticProgramMode,
    TokenAsyncGen,
    TokenGen,
)


# NOTE: These two protocols are needed to appease mypy
@runtime_checkable
class MessagesToPromptType(Protocol):
    def __call__(self, messages: Sequence[ChatMessage]) -> str:
        pass


@runtime_checkable
class CompletionToPromptType(Protocol):
    def __call__(self, prompt: str) -> str:
        pass


def stream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseGen,
) -> TokenGen:
    """Convert a stream completion response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in completion_response_gen:
            yield response.delta or ""

    return gen()


def stream_chat_response_to_tokens(
    chat_response_gen: ChatResponseGen,
) -> TokenGen:
    """Convert a stream chat response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in chat_response_gen:
            yield response.delta or ""

    return gen()


async def astream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert a stream completion response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in completion_response_gen:
            yield response.delta or ""

    return gen()


async def astream_chat_response_to_tokens(
    chat_response_gen: ChatResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert a stream chat response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in chat_response_gen:
            yield response.delta or ""

    return gen()

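# NOTE: A usage sketch for the token helpers above, assuming a concrete `llm`
# instance (hypothetical name) whose `stream_complete` yields responses
# carrying incremental text in `.delta`:
#
#   stream = llm.stream_complete("Tell me a joke.", formatted=True)
#   for token in stream_completion_response_to_tokens(stream):
#       print(token, end="", flush=True)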

def default_completion_to_prompt(prompt: str) -> str:
    return prompt


class LLM(BaseLLM):
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for LLM calls."
    )
    messages_to_prompt: MessagesToPromptType = Field(
        description="Function to convert a list of messages to an LLM prompt.",
        default=generic_messages_to_prompt,
        exclude=True,
    )
    completion_to_prompt: CompletionToPromptType = Field(
        description="Function to convert a completion to an LLM prompt.",
        default=default_completion_to_prompt,
        exclude=True,
    )
    output_parser: Optional[BaseOutputParser] = Field(
        description="Output parser to parse, validate, and correct errors programmatically.",
        default=None,
        exclude=True,
    )
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    # deprecated
    query_wrapper_prompt: Optional[BasePromptTemplate] = Field(
        description="Query wrapper prompt for LLM calls.",
        default=None,
        exclude=True,
    )

    @validator("messages_to_prompt", pre=True)
    def set_messages_to_prompt(
        cls, messages_to_prompt: Optional[MessagesToPromptType]
    ) -> MessagesToPromptType:
        return messages_to_prompt or generic_messages_to_prompt

    @validator("completion_to_prompt", pre=True)
    def set_completion_to_prompt(
        cls, completion_to_prompt: Optional[CompletionToPromptType]
    ) -> CompletionToPromptType:
        return completion_to_prompt or default_completion_to_prompt

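    # NOTE: Both converters are injectable at construction time; passing None
    # falls back to the generic defaults via the validators above. A sketch
    # with an illustrative template (`SomeLLM` is a hypothetical subclass):
    #
    #   def custom_messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
    #       return "\n".join(f"<|{m.role.value}|>\n{m.content}" for m in messages)
    #
    #   llm = SomeLLM(messages_to_prompt=custom_messages_to_prompt)
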
    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        formatted_prompt = prompt.format(
            llm=self,
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            **prompt_args,
        )
        if self.output_parser is not None:
            formatted_prompt = self.output_parser.format(formatted_prompt)
        return self._extend_prompt(formatted_prompt)

    def _get_messages(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> List[ChatMessage]:
        messages = prompt.format_messages(llm=self, **prompt_args)
        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return self._extend_messages(messages)

    def structured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return program(**prompt_args)

    async def astructured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return await program.acall(**prompt_args)

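    # NOTE: A sketch of structured prediction against a hypothetical pydantic
    # model (`Song` and `llm` are illustrative, not part of this module):
    #
    #   class Song(BaseModel):
    #       title: str
    #       length_seconds: int
    #
    #   song = llm.structured_predict(
    #       Song, PromptTemplate("Write a song about {topic}."), topic="rain"
    #   )
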
    def _parse_output(self, output: str) -> str:
        if self.output_parser is not None:
            return str(self.output_parser.parse(output))

        return output

    def predict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = self.complete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

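    # NOTE: A minimal sketch of `predict`, assuming a concrete `llm` instance
    # (hypothetical name); routing between chat and completion is decided by
    # `metadata.is_chat_model`:
    #
    #   answer = llm.predict(
    #       PromptTemplate("Q: {question}\nA:"), question="Why is the sky blue?"
    #   )
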
    def stream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = self.stream_complete(formatted_prompt, formatted=True)
            stream_tokens = stream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

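    # NOTE: A sketch of draining the returned token generator, assuming the
    # same hypothetical `llm`:
    #
    #   for token in llm.stream(PromptTemplate("Q: {q}\nA:"), q="Why is the sky blue?"):
    #       print(token, end="", flush=True)
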
    async def apredict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = await self.acomplete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    async def astream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = await self.astream_complete(
                formatted_prompt, formatted=True
            )
            stream_tokens = await astream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

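    # NOTE: The async variants mirror the sync ones. A sketch, assuming a
    # running event loop and the same hypothetical `llm`:
    #
    #   async def main() -> None:
    #       text = await llm.apredict(PromptTemplate("Q: {q}\nA:"), q="...")
    #       tokens = await llm.astream(PromptTemplate("Q: {q}\nA:"), q="...")
    #       async for token in tokens:
    #           print(token, end="", flush=True)
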
    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt

        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages

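    # NOTE: Illustrative effect of the two extension paths above, assuming
    # system_prompt="You are terse." and no query wrapper prompt:
    #
    #   _extend_prompt("Hello")      ->  "You are terse.\n\nHello"
    #   _extend_messages([user_msg]) ->  [ChatMessage(role=SYSTEM, ...), user_msg]
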
    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            return LLMChatComponent(llm=self, **kwargs)
        else:
            return LLMCompleteComponent(llm=self, **kwargs)


class BaseLLMComponent(QueryComponent):
    """Base LLM component."""

    llm: LLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        self.llm.callback_manager = callback_manager


class LLMCompleteComponent(BaseLLMComponent):
    """LLM completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # do special check to see if prompt is a list of chat messages
        if isinstance(input["prompt"], get_args(List[ChatMessage])):
            input["prompt"] = self.llm.messages_to_prompt(input["prompt"])
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
            input["prompt"] = self.llm.completion_to_prompt(input["prompt"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        if self.streaming:
            response = self.llm.stream_complete(prompt, formatted=True)
        else:
            response = self.llm.complete(prompt, formatted=True)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        response = await self.llm.acomplete(prompt, formatted=True)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only complete for now
        return InputKeys.from_keys({"prompt"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})


class LLMChatComponent(BaseLLMComponent):
    """LLM chat component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "messages" not in input:
            raise ValueError("Messages must be in input dict.")

        # if `messages` is a string, convert to a list of chat messages
        if isinstance(input["messages"], get_args(StringableInput)):
            input["messages"] = validate_and_convert_stringable(input["messages"])
            input["messages"] = prompt_to_messages(str(input["messages"]))

        for message in input["messages"]:
            if not isinstance(message, ChatMessage):
                raise ValueError("Messages must be a list of ChatMessage")
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = self.llm.stream_chat(messages)
        else:
            response = self.llm.chat(messages)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = await self.llm.astream_chat(messages)
        else:
            response = await self.llm.achat(messages)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only chat for now
        return InputKeys.from_keys({"messages"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
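
# NOTE: A sketch of driving these components, assuming `as_query_component`
# (inherited by BaseLLM via ChainableMixin) dispatches here through
# `_as_query_component`, and a concrete `llm` instance (hypothetical name):
#
#   component = llm.as_query_component(streaming=False)
#   result = component.run_component(prompt="Tell me a joke.")
#   print(result["output"])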